diff --git a/.gitattributes b/.gitattributes index e8d3bd1c066b770cf4dc99b1489fe462e6d142e6..19ea8bd262b6dd4e26b701ebaad44ee90594b664 100644 --- a/.gitattributes +++ b/.gitattributes @@ -10,3 +10,4 @@ build/torch210-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=l build/torch29-cxx11-cu126-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text build/torch29-cxx11-cu128-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text build/torch29-cxx11-cu129-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/reduce_split_k.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/reduce_split_k.h new file mode 100644 index 0000000000000000000000000000000000000000..92b57aae26e22cc7a5859568882a9661f022c5a7 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/reduce_split_k.h @@ -0,0 +1,232 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Kernel performing a reduction over densely packed tensors in global memory +*/ + +#pragma once + +#include "cutlass/device_kernel.h" +#include "cutlass/reduction/kernel/reduce_split_k.h" +#include "cutlass/cuda_host_adapter.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ReductionKernel_ +> +class ReduceSplitK { +public: + using ReductionKernel = ReductionKernel_; + + using Shape = typename ReductionKernel::Shape; + using ReductionOp = typename ReductionKernel::ReductionOp; + using OutputOp = typename ReductionKernel::OutputOp; + + using ElementWorkspace = typename ReductionKernel::ElementWorkspace; + using ElementAccumulator = typename ReductionKernel::ElementAccumulator; + using ElementOutput = typename ReductionKernel::ElementOutput; + + using WorkspaceTensorRef = typename ReductionKernel::WorkspaceTensorRef; + using OutputTensorRef = typename ReductionKernel::OutputTensorRef; + + using StrideIndex = typename ReductionKernel::StrideIndex; + + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + MatrixCoord problem_size{0,0}; + int partitions{1}; + size_t partition_stride{0}; + WorkspaceTensorRef workspace{}; + OutputTensorRef destination{}; + OutputTensorRef source{}; + typename OutputOp::Params output{}; + typename ReductionOp::Params reduction{}; + + // + // Methods + // + + /// Default ctor + Arguments() = default; + + CUTLASS_HOST_DEVICE + Arguments( + MatrixCoord const & problem_size + ): + problem_size(problem_size) { } + + CUTLASS_HOST_DEVICE + Arguments( + MatrixCoord problem_size_, + int partitions_, + size_t partition_stride_, + WorkspaceTensorRef workspace_, + OutputTensorRef destination_, 
+ OutputTensorRef source_, + typename OutputOp::Params output_ = typename OutputOp::Params(), + typename ReductionOp::Params reduction_ = typename ReductionOp::Params() + ): + problem_size(problem_size_), + partitions(partitions_), + partition_stride(partition_stride_), + workspace(workspace_), + destination(destination_), + source(source_), + output(output_), + reduction(reduction_) + { + + } + + }; + +private: + /// Kernel parameters object + typename ReductionKernel::Params params_; + +public: + /// Constructs Reduction SplitK + ReduceSplitK() { } + + /// Determines whether the ReduceSplitK can execute the given problem. + static Status can_implement(Arguments const &args) { + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + // needs no additional workspace + return 0; + } + + /// Initializes Reduction state from arguments. + Status initialize( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + // initialize the params structure from the arguments + params_ = typename ReductionKernel::Params( + args.problem_size, + args.partitions, + args.partition_stride, + args.workspace, + args.destination, + args.source, + args.output, + args.reduction + ); + + return Status::kSuccess; + + } + + /// Initializes Reduction kernel state from arguments. + Status update(Arguments const &args, void *workspace = nullptr) { + + // update the params structure from the arguments + params_.workspace.reset(args.workspace.non_const_ref().data()); + params_.destination.reset(args.destination.non_const_ref().data()); + params_.source.reset(args.source.non_const_ref().data()); + params_.output = args.output; + params_.reduction = args.reduction; + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { + + // + // Launch reduction kernel + // + dim3 block = ReductionKernel::block_shape(); + dim3 grid = ReductionKernel::grid_shape(params_.problem_size); + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + void* kernel_params[] = {¶ms_}; + cuda_adapter->launch( + grid, dim3(1,1,1), block, 0, stream, kernel_params, kernel_index); + } + } + else { + cutlass::arch::synclog_setup(); + Kernel<<< grid, block, 0, stream >>>(params_); + } + + cudaError_t result = cudaGetLastError(); + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { + return run(stream, cuda_adapter, kernel_index); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream,cuda_adapter, kernel_index); + } + + return status; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace reduction +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..26a0249e9c259dbf2930832d2819188ec74bda60 --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce.h @@ -0,0 +1,264 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Kernel performing a reduction over one or more ranks of an affine tensor +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/reduction/device/tensor_reduce_affine_strided.h" +#include "cutlass/reduction/device/tensor_reduce_affine_contiguous.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tensor reduction operator on specific CUTLASS layouts over exactly one index +template < + typename ElementOutput_, + typename ElementSource_, + typename Layout_, + typename ReductionOp_, + int VectorLength_ = 1, + typename ElementCompute_ = ElementOutput_ +> +struct TensorReduction { + + using ElementOutput = ElementOutput_; + using ElementSource = ElementSource_; + using Layout = Layout_; + using ReductionOp = ReductionOp_; + static int const kVectorLength = VectorLength_; + using ElementCompute = ElementCompute_; + + using TensorCoord = typename Layout::TensorCoord; + + /// Reduction operator + using ReductionDeviceStridedOperator = TensorReductionAffineStrided< + 4, 3, ElementOutput, ElementSource, ReductionOp, kVectorLength, ElementCompute + >; + + using ReductionDeviceContiguousOperator = TensorReductionAffineContiguous< + 4, 3, ElementOutput, ElementSource, ReductionOp, kVectorLength, ElementCompute + >; + + // + // Data members + // + + ReductionDeviceStridedOperator reduction_strided; + ReductionDeviceContiguousOperator reduction_contiguous; + int reduction_index; + + // + // Methods + // + + /// + TensorReduction( + TensorCoord extent, + int reduction_index_ + ): + reduction_index(reduction_index_) { + + Coord<4> extent_affine; + + switch 
(reduction_index) { + case 0: + extent_affine[0] = extent[1]; + extent_affine[1] = extent[2]; + extent_affine[2] = extent[0]; + extent_affine[3] = extent[3]; + break; + case 1: + extent_affine[0] = extent[0]; + extent_affine[1] = extent[2]; + extent_affine[2] = extent[1]; + extent_affine[3] = extent[3]; + break; + case 2: + extent_affine[0] = extent[0]; + extent_affine[1] = extent[1]; + extent_affine[2] = extent[2]; + extent_affine[3] = extent[3]; + break; + case 3: + extent_affine[0] = extent[0]; + extent_affine[1] = extent[1]; + extent_affine[2] = extent[2]; + extent_affine[3] = extent[3]; + break; + default: break; + } + + if (reduction_index == 3) { + reduction_contiguous = ReductionDeviceContiguousOperator(extent_affine); + } + else { + reduction_strided = ReductionDeviceStridedOperator(extent_affine); + } + } + + /// Simple check to verify the object is initialized correctly + bool good() const { + if (reduction_index == 3) { + return reduction_contiguous.good(); + } + return reduction_strided.good(); + } + + /// Size of one workspace + int64_t workspace_stride() const { + if (reduction_index == 3) { + return reduction_contiguous.workspace_stride(); + } + else { + return reduction_strided.workspace_stride(); + } + } + + /// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs + int64_t workspace_size() const { + if (reduction_index == 3) { + return reduction_contiguous.workspace_size(); + } + else { + return reduction_strided.workspace_size(); + } + } + + /// Helper to use overloaded function call operator + Status reduce( + TensorRef dst_ref, + TensorRef src_ref, + void *device_workspace_ptr = nullptr, + ElementCompute reduction_identity = ElementCompute(), + ReductionOp reduction_op = ReductionOp(), + cudaStream_t stream = nullptr) { + + int64_t src_stride[3]; + int64_t dst_stride[3]; + + switch (reduction_index) { + case 0: + src_stride[0] = src_ref.stride()[1]; + src_stride[1] = src_ref.stride()[0]; + src_stride[2] = 
src_ref.stride()[2]; + dst_stride[0] = dst_ref.stride()[1]; + dst_stride[1] = dst_ref.stride()[0]; + break; + case 1: + src_stride[0] = src_ref.stride()[2]; + src_stride[1] = src_ref.stride()[0]; + src_stride[2] = src_ref.stride()[1]; + dst_stride[0] = dst_ref.stride()[2]; + dst_stride[1] = dst_ref.stride()[0]; + break; + case 2: + src_stride[0] = src_ref.stride()[2]; + src_stride[1] = src_ref.stride()[1]; + src_stride[2] = src_ref.stride()[0]; + dst_stride[0] = dst_ref.stride()[2]; + dst_stride[1] = dst_ref.stride()[1]; + break; + case 3: + src_stride[0] = src_ref.stride()[2]; + src_stride[1] = src_ref.stride()[1]; + src_stride[2] = src_ref.stride()[0]; + + dst_stride[0] = dst_ref.stride()[2]; + dst_stride[1] = dst_ref.stride()[1]; + dst_stride[2] = dst_ref.stride()[0]; + + default: break; + } + + if (reduction_index == 3) { + return reduction_contiguous( + dst_ref.data(), + dst_stride, + src_ref.data(), + src_stride, + device_workspace_ptr, + reduction_identity, + reduction_op, + stream); + } + else { + return reduction_strided( + dst_ref.data(), + dst_stride, + src_ref.data(), + src_stride, + device_workspace_ptr, + reduction_identity, + reduction_op, + stream); + } + } + + Status operator()( + TensorRef dst_ref, + TensorRef src_ref, + void *device_workspace_ptr = nullptr, + ElementCompute reduction_identity = ElementCompute(), + ReductionOp reduction_op = ReductionOp(), + cudaStream_t stream = nullptr) { + + return reduce( + dst_ref, + src_ref, + device_workspace_ptr, + reduction_identity, + reduction_op, + stream); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reduction +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h new file mode 100644 index 0000000000000000000000000000000000000000..c00c368165902bdda08f6316a07be19668dc0fb9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_contiguous.h @@ -0,0 +1,374 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Kernel performing a reduction over one or more ranks of an affine tensor +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tensor reduction operator on layouts which are affine +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (e.g. 
ND => 2) + typename ElementOutput_, + typename ElementSource_, + typename ReductionOp_, + int VectorLength = 1, + typename ElementCompute_ = ElementOutput_, + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +struct TensorReductionAffineContiguous { + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + + using ElementOutput = ElementOutput_; + using ElementSource = ElementSource_; + using ReductionOp = ReductionOp_; + using ElementCompute = ElementCompute_; + + // + // Data members + // + + /// Internal status field + Status status; + + /// Extent of tensor in source layout + Coord extent; + + /// Number of points in the outer index space + int64_t outer_count; + + /// Number of elements in the inner index space + int64_t inner_count; + + /// Number of workspaces needed + int workspace_count; + + /// CUDA Grid shape (.x => contiguous, .y => outer, .z => inner) + dim3 grid_shape; + + /// CUDA Threadblock shape (.x => contiguous, .y => outer, .z => inner) + dim3 threadblock_shape; + + /// CUDA grid shape for the final reduction step if needed + dim3 grid_final; + + /// CUDA threadblock shape for the final reduction step if needed + dim3 threadblock_final; + +private: + // + // Methods + // + + /// Helper to reshape 'count' such that it is less than 2 x 'ext' + static int reshape_pow2(int ext, int count) { + if (ext > count) { + return 1; + } + int x = 1; + for (; count >= ext * 2; ) { + count >>= 1; + x <<= 1; + } + return x; + } + +public: + + /// Default ctor + TensorReductionAffineContiguous(): + status(Status::kErrorInvalidProblem), + extent(), + outer_count(0), + inner_count(0), + workspace_count(0), + grid_shape(0, 0, 0), + threadblock_shape(0, 0, 0) { } + + /// Constructor + 
TensorReductionAffineContiguous( + Coord extent_, + int target_threadblock_count = 128 + ): + status(Status::kSuccess), + extent(extent_), + outer_count(0), + inner_count(0), + workspace_count(0) { + + // + // Plan the parallel mapping strategy. + // + + outer_count = 1; + inner_count = 1; + + // Compute number of elements in strided ranks + for (int p = 0; p < kReducedRank; ++p) { + outer_count *= extent[p]; + } + + for (int p = 0; p < kInnerRank; ++p) { + inner_count *= extent[kReducedRank + p]; + } + + int cta_count_x = 1; + int cta_count_y = 1; + int cta_count_z = 1; + + int cta_threads_x = kThreads; + int cta_threads_y = 1; + int cta_threads_z = 1; + + // Determine CTA shape + int64_t inner_vector_count = inner_count / kVectorLength; + + // Priority 1. Assign threadblocks to outer indices if possible + if (outer_count > target_threadblock_count) { + cta_count_x = 1; + cta_count_y = target_threadblock_count; + cta_count_z = 1; + } + else { + + cta_count_y = int(outer_count); + int remaining_ctas = target_threadblock_count / cta_count_y; + + // Priority 2. Assign inner dimensions to one CTA + if (inner_vector_count > cta_threads_x) { + int64_t cta_z_bound = inner_vector_count / cta_threads_x; + if (cta_z_bound > remaining_ctas) { + cta_count_z = remaining_ctas; + } + else { + cta_count_z = int(cta_z_bound); + } + } + else { + cta_threads_x = reshape_pow2(int(inner_vector_count), cta_threads_x); + cta_count_z = 1; + } + } + + grid_shape = dim3(cta_count_x, cta_count_y, cta_count_z); + threadblock_shape = dim3(cta_threads_x, cta_threads_y, cta_threads_z); + + workspace_count = (cta_count_z > 1 ? 
cta_count_z : 0); + + // Determine shape of final reduction kernel if needed + if (workspace_count) { + + int final_threads = kThreads; + int final_ctas = 1; + + if (outer_count > kThreads) { + final_ctas = int(outer_count + kThreads - 1) / kThreads; + } + else { + final_threads = int(outer_count); + } + + grid_final = dim3(final_ctas, 1, 1); + threadblock_final = dim3(final_threads, 1, 1); + } + else { + grid_final = dim3(0, 0, 0); + threadblock_final = dim3(0, 0, 0); + } + } + + /// Simple check to verify the object is initialized correctly + bool good() const { + return status == Status::kSuccess; + } + + /// Size (in bytes) of workspace elements which are densely packed together + int64_t workspace_stride() const { + + // Error condition + if (!good()) { + return 0; + } + + return outer_count * sizeof_bits::value / 8; + } + + /// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs + int64_t workspace_size() const { + + // Error condition + if (!good()) { + return 0; + } + + // No reduction across CTAs + if (grid_shape.z == 1) { + return 0; + } + + return workspace_stride() * grid_shape.z; + } + + /// Performs a reduction + Status reduce( + ElementOutput *dst_ptr, ///< Pointer to destination tensor + int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) + ElementSource const *src_ptr, ///< Pointer to source tensor + int64_t src_stride[], ///< Stride vector (of length kRank - 1) + void *device_workspace_ptr = nullptr, ///< Device workspace + ElementCompute reduction_identity = ElementCompute(), ///< Reduction identity element + ReductionOp reduction_op = ReductionOp(), ///< Reduction operator + cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched + + // Initial status check + if (!good()) { + return status; + } + + // Guard against null workspace + if (workspace_count > 1 && device_workspace_ptr == nullptr) { + return Status::kErrorWorkspaceNull; + } + + // Define reduction kernel 
+ using ReductionKernel = kernel::TensorReductionAffineContiguous< + kRank, + kReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + kVectorLength, + ElementCompute, + kThreads>; + + using FinalReductionKernel = kernel::TensorReductionAffineContiguousFinal< + kRank, + kReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + kVectorLength, + ElementCompute, + kThreads>; + + using Params = typename ReductionKernel::Params; + + // Construct the parameters + Params params( + extent, + dst_ptr, + dst_stride, + src_ptr, + src_stride, + static_cast(device_workspace_ptr), + workspace_stride(), + workspace_count, + reduction_op, + reduction_identity); + + // Shared memory size + int shared_mem_bytes = sizeof(typename ReductionKernel::SharedStorage); + + // Launch the kernel + cutlass::arch::synclog_setup(); + Kernel<<< grid_shape, threadblock_shape, shared_mem_bytes, stream >>>(params); + + // Check error condition + if (cudaPeekAtLastError() == cudaSuccess) { + status = Status::kSuccess; + } + else { + status = Status::kErrorInternal; + } + + // Final reduction kernel + if (workspace_count) { + Kernel<<< grid_final, threadblock_final, 0, stream >>>(params); + } + + // Check error condition + if (cudaPeekAtLastError() == cudaSuccess) { + status = Status::kSuccess; + } + else { + status = Status::kErrorInternal; + } + + return status; + } + + /// Helper to use overloaded function call operator + Status operator()( + ElementOutput *dst_ptr, ///< Pointer to destination tensor + int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) + ElementSource const *src_ptr, ///< Pointer to source tensor + int64_t src_stride[], ///< Stride vector (of length kRank - 1) + void *device_workspace_ptr = nullptr, ///< Pointer to device workspace + ElementCompute reduction_identity = ElementCompute(), ///< Reduction identity element + ReductionOp reduction_op = ReductionOp(), ///< Reduction operator + cudaStream_t stream = nullptr) { ///< CUDA Stream into which 
all kernels are launched + + return reduce(dst_ptr, dst_stride, src_ptr, src_stride, device_workspace_ptr, reduction_identity, reduction_op, stream); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reduction +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h new file mode 100644 index 0000000000000000000000000000000000000000..c85d6dcbf13ba17a82b252124313c58f901e55f5 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/device/tensor_reduce_affine_strided.h @@ -0,0 +1,362 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Kernel performing a reduction over one or more ranks of an affine tensor +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/reduction/kernel/tensor_reduce_affine_strided.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tensor reduction operator on layouts which are affine +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. 
NC => 2) + typename ElementOutput_, + typename ElementSource_, + typename ReductionOp_, + int VectorLength = 1, + typename ElementCompute_ = ElementOutput_, + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +struct TensorReductionAffineStrided { + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + + using ElementOutput = ElementOutput_; + using ElementSource = ElementSource_; + using ReductionOp = ReductionOp_; + using ElementCompute = ElementCompute_; + + // + // Data members + // + + /// Internal status field + Status status; + + /// Extent of tensor in source layout + Coord extent; + + /// Number of points in the outer index space + int64_t outer_count; + + /// Number of elements in the inner index space + int64_t inner_count; + + /// Number of workspaces needed + int workspace_count; + + /// CUDA Grid shape (.x => contiguous, .y => outer, .z => inner) + dim3 grid_shape; + + /// CUDA Threadblock shape (.x => contiguous, .y => outer, .z => inner) + dim3 threadblock_shape; + + /// CUDA grid shape for the final reduction step if needed + dim3 grid_final; + + /// CUDA threadblock shape for the final reduction step if needed + dim3 threadblock_final; + +private: + // + // Methods + // + + /// Helper to reshape 'count' such that it is less than 2 x 'ext' + static int reshape_pow2(int ext, int count) { + if (ext > count) { + return 1; + } + int x = 1; + for (; count >= ext * 2; ) { + count >>= 1; + x <<= 1; + } + return x; + } + +public: + + /// Default ctor + TensorReductionAffineStrided(): + status(Status::kErrorInvalidProblem), + extent(), + outer_count(0), + inner_count(0), + workspace_count(0), + grid_shape(0, 0, 0), + threadblock_shape(0, 0, 0) { } + + /// Constructor + 
TensorReductionAffineStrided( + Coord extent_, + int target_threadblock_count = 128 + ): + status(Status::kSuccess), + extent(extent_), + outer_count(0), + inner_count(0), + workspace_count(0) { + + // + // Plan the parallel mapping strategy. + // + + outer_count = 1; + inner_count = 1; + + // Compute number of elements in strided ranks + for (int p = 0; p < kReducedRank - 1; ++p) { + outer_count *= extent[p]; + } + + for (int p = 0; p < kInnerRank; ++p) { + inner_count *= extent[kReducedRank + p - 1]; + } + + // Compute plan for the reduction + int extent_c = extent[kRank - 1]; + int vectors_c = (extent_c -1 + kVectorLength) / kVectorLength; + + // Determine CTA shape + int cta_width = kThreads * kVectorLength; + int cta_ways = reshape_pow2(extent_c, cta_width); + int cta_threads_x = kThreads / cta_ways; + + threadblock_shape = dim3(cta_threads_x, 1, std::min(cta_ways, 64)); + + // This leads to an error. + if (threadblock_shape.z > 1) { + if (threadblock_shape.y != 1) { + status = Status::kErrorInternal; + return; + } + } + + // Determine grid shape + int cta_count_x = (vectors_c + cta_threads_x - 1) / cta_threads_x; + int cta_count_y = std::max(1, target_threadblock_count / cta_count_x); + + // Limit the number of CTAs assigned to outer dimension + if (int64_t(cta_count_y * threadblock_shape.y) > outer_count) { + cta_count_y = int(outer_count + threadblock_shape.y - 1) / threadblock_shape.y; + } + + // Limit the number of CTAs assigned to inner dimension + int cta_count_z = std::max(1, target_threadblock_count / cta_count_y); + if (int64_t(cta_count_z * threadblock_shape.z) > inner_count) { + cta_count_z = int(inner_count + threadblock_shape.z - 1) / threadblock_shape.z; + } + + grid_shape = dim3(cta_count_x, cta_count_y, cta_count_z); + workspace_count = (cta_count_z > 1 ? 
cta_count_z : 0); + + // Determine shape of final reduction kernel if needed + grid_final = dim3(cta_count_x, int(outer_count)); + threadblock_final = dim3(cta_threads_x, 1, 1); + } + + /// Simple check to verify the object is initialized correctly + bool good() const { + return status == Status::kSuccess; + } + + /// Size of one CTA's workspace + int64_t workspace_stride() const { + + // Error condition + if (!good()) { + return 0; + } + + int vector_size_bytes = kVectorLength * sizeof_bits::value / 8; + + return extent[kRank - 1] * vector_size_bytes; + } + + /// Returns the size (in bytes) of a temporary workspace needed for reduction across CTAs + int64_t workspace_size() const { + + // Error condition + if (!good()) { + return 0; + } + + // No reduction across CTAs + if (grid_shape.z == 1) { + return 0; + } + + return workspace_stride() * outer_count * grid_shape.z; + } + + /// Performs a reduction + Status reduce( + ElementOutput *dst_ptr, ///< Pointer to destination tensor + int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) + ElementSource const *src_ptr, ///< Pointer to source tensor + int64_t src_stride[], ///< Stride vector (of length kRank - 1) + void *device_workspace_ptr = nullptr, ///< Device workspace + ElementCompute reduction_identity = ElementCompute(), ///< Reduciton identity + ReductionOp reduction_op = ReductionOp(), ///< Reduction operator + cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched + + // Initial status check + if (!good()) { + return status; + } + + // Guard against null workspace + if (workspace_count > 1 && device_workspace_ptr == nullptr) { + return Status::kErrorWorkspaceNull; + } + + // Define reduction kernel + using ReductionKernel = kernel::TensorReductionAffineStrided< + kRank, + kReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + kVectorLength, + ElementCompute, + kThreads>; + + using FinalReductionKernel = kernel::TensorReductionAffineStridedFinal< + 
kRank, + kReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + kVectorLength, + ElementCompute, + kThreads>; + + using Params = typename ReductionKernel::Params; + + // Construct the parameters + Params params( + extent, + dst_ptr, + dst_stride, + src_ptr, + src_stride, + static_cast(device_workspace_ptr), + workspace_stride(), + workspace_count, + reduction_op, + reduction_identity); + + // Shared memory size + int shared_mem_bytes = sizeof(typename ReductionKernel::SharedStorage); + + // Launch the kernel + cutlass::arch::synclog_setup(); + Kernel<<< grid_shape, threadblock_shape, shared_mem_bytes, stream >>>(params); + + // Check error condition + if (cudaPeekAtLastError() == cudaSuccess) { + status = Status::kSuccess; + } + else { + status = Status::kErrorInternal; + } + + // Final reduction kernel + if (workspace_count) { + + Kernel<<< grid_final, threadblock_final, 0, stream >>>(params); + + // Check error condition + if (cudaPeekAtLastError() == cudaSuccess) { + status = Status::kSuccess; + } + else { + status = Status::kErrorInternal; + } + } + + return status; + } + + /// Helper to use overloaded function call operator + Status operator()( + ElementOutput *dst_ptr, ///< Pointer to destination tensor + int64_t dst_stride[], ///< Stride vector (of length kReducedRank - 1) + ElementSource const *src_ptr, ///< Pointer to source tensor + int64_t src_stride[], ///< Stride vector (of length kRank - 1) + void *device_workspace_ptr = nullptr, ///< Pointer to device workspace + ElementCompute reduction_identity = ElementCompute(), ///< Reduciton identity + ReductionOp reduction_op = ReductionOp(), ///< Reduction operator + cudaStream_t stream = nullptr) { ///< CUDA Stream into which all kernels are launched + + return reduce( + dst_ptr, + dst_stride, + src_ptr, + src_stride, + device_workspace_ptr, + reduction_identity, + reduction_op, + stream); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + 
+} // namespace device +} // namespace reduction +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h new file mode 100644 index 0000000000000000000000000000000000000000..3d39dc751c4bdef328398c5a94e5462136728f6a --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_softmax_final.h @@ -0,0 +1,267 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Kernel performing a final reduction for softmax +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace kernel { + +template < + typename ElementNorm_, + typename ElementSum_, + typename ElementSoftmaxCompute_, + typename ThreadblockShape_, + bool GroupedProblem = false +> +class ApplySoftmaxFinalReduction { +public: + + using ElementNorm = ElementNorm_; + using ElementSum = ElementSum_; + using ElementSoftmaxCompute = ElementSoftmaxCompute_; + using ThreadblockShape = ThreadblockShape_; + static const bool isGroupedProblem = GroupedProblem; + + // + // Arguments + // + + struct Arguments { + + cutlass::gemm::GemmCoord* problem_sizes{nullptr}; + cutlass::gemm::GemmCoord problem_size{}; + ElementNorm* block_Norm{nullptr}; + ElementSum* block_Sum{nullptr}; + int64_t* offset_Norm_Device{nullptr}; + int64_t* offset_Sum_Device{nullptr}; + int64_t batch_stride_Max{0}; + int64_t batch_stride_Sum{0}; + + // + // Methods + // + 
Arguments() { } + + // Non-grouped constructor without batching + Arguments( + cutlass::gemm::GemmCoord problem_size, + ElementNorm* block_Norm, + ElementSum* block_Sum + ): + problem_size(problem_size), + block_Norm(block_Norm), + block_Sum(block_Sum), + problem_sizes(nullptr), + offset_Norm_Device(nullptr), + offset_Sum_Device(nullptr), + batch_stride_Max(0), + batch_stride_Sum(0) + { + + } + + // Non-grouped constructor with batching + Arguments( + cutlass::gemm::GemmCoord problem_size, + ElementNorm* block_Norm, + ElementSum* block_Sum, + int64_t batch_stride_Max, + int64_t batch_stride_Sum + ): + problem_size(problem_size), + block_Norm(block_Norm), + block_Sum(block_Sum), + batch_stride_Max(batch_stride_Max), + batch_stride_Sum(batch_stride_Sum), + problem_sizes(nullptr), + offset_Norm_Device(nullptr), + offset_Sum_Device(nullptr) + { + + } + + + // Grouped constructor + Arguments( + cutlass::gemm::GemmCoord *problem_sizes, + ElementNorm* block_Norm, + ElementSum* block_Sum, + int64_t* offset_Norm_Device, + int64_t* offset_Sum_Device + ): + problem_sizes(problem_sizes), + problem_size(cutlass::gemm::GemmCoord(0, 0, 0)), + block_Norm(block_Norm), + block_Sum(block_Sum), + offset_Norm_Device(offset_Norm_Device), + offset_Sum_Device(offset_Sum_Device) + { + + } + }; + + struct SharedStorage { + + + }; + + // + // Params struct + // + + struct Params { + Arguments args; + + // + // Methods + // + Params() { } + + Params(Arguments const &args_): args(args_) { } + }; + +private: + +public: + + CUTLASS_DEVICE + ApplySoftmaxFinalReduction() { } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + apply(params, shared_storage); + } + +private: + + /// Full reduction + CUTLASS_DEVICE + void apply(Params const ¶ms, SharedStorage &shared_storage) { + + int tid = threadIdx.x; + int bid = blockIdx.x; + int bdim = blockDim.x; + + int block_batch = blockIdx.z; + + // defining three vars for a general reduction module + 
cutlass::gemm::GemmCoord problem_size = isGroupedProblem ? params.args.problem_sizes[bid] : params.args.problem_size; + int m_dim_in_loop = isGroupedProblem ? problem_size.m() : tid + bdim; + int access_offset = isGroupedProblem ? 0 : bid * bdim; + + if (!isGroupedProblem && access_offset + tid >= problem_size.m()) return; + + ElementNorm *curr_ptr_Max = isGroupedProblem ? \ + params.args.block_Norm + params.args.offset_Norm_Device[bid] : \ + params.args.block_Norm + block_batch * params.args.batch_stride_Max; + ElementSum *curr_ptr_Sum = isGroupedProblem ? \ + params.args.block_Sum + params.args.offset_Sum_Device[bid] : \ + params.args.block_Sum + block_batch * params.args.batch_stride_Sum; + + int threadblock_num = (problem_size.n() + ThreadblockShape::kN - 1) / ThreadblockShape::kN; + + using ConvertSumOutput = cutlass::NumericConverter; + using ConvertNormOutput = cutlass::NumericConverter; + + using ConvertSum = cutlass::NumericConverter; + using ConvertNorm = cutlass::NumericConverter; + + ConvertSum convert_sum; + ConvertNorm convert_norm; + + ConvertSumOutput convert_sum_output; + ConvertNormOutput convert_norm_output; + + uint32_t float_max_bits = 0xff7fffff; + float min_float = reinterpret_cast(float_max_bits); + + CUTLASS_PRAGMA_UNROLL + for (int idx_m = tid; idx_m < m_dim_in_loop; idx_m += bdim) { + ElementNorm *access_n = curr_ptr_Max + idx_m + access_offset; + ElementSum *access_s = curr_ptr_Sum + idx_m + access_offset; + ElementNorm *access_n_bak = access_n; + ElementSum *access_s_bak = access_s; + ElementSoftmaxCompute max_val = ElementSoftmaxCompute(min_float); + ElementSoftmaxCompute sum_val = ElementSoftmaxCompute(0); + ElementNorm fetch_n; + ElementSum fetch_s; + + CUTLASS_PRAGMA_UNROLL + for (int idx_n = 0; idx_n < threadblock_num; idx_n++) { + cutlass::arch::global_load(fetch_n, access_n, true); + max_val = cutlass::fast_max(max_val, convert_norm(fetch_n)); + access_n += problem_size.m(); + } + + access_n = access_n_bak; + + 
CUTLASS_PRAGMA_UNROLL + for (int idx_n = 0; idx_n < threadblock_num; idx_n++) { + cutlass::arch::global_load(fetch_n, access_n, true); + cutlass::arch::global_load(fetch_s, access_s, true); + sum_val += convert_sum(fetch_s) * cutlass::fast_exp(convert_norm(fetch_n) - max_val); + access_n += problem_size.m(); + access_s += problem_size.m(); + } + + ElementSoftmaxCompute inv_sum = cutlass::constants::one() / sum_val; + + access_n = access_n_bak; + access_s = access_s_bak; + + access_n[0] = convert_norm_output(max_val); + access_s[0] = convert_sum_output(inv_sum); + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace reduction +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h new file mode 100644 index 0000000000000000000000000000000000000000..f6d26666957a58321c579b191ec06c84503e8ca2 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/reduce_split_k.h @@ -0,0 +1,248 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Kernel performing a reduction over densely packed tensors in global memory +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/layout/matrix.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape_, ///< shape of CTA (concept: MatrixShape) + typename OutputOp_ , ///< output operator (concept: epilogue::thread operator) + typename ReductionOp_, ///< reduction operator (concept: ReductionOperator) + int PartitionsPerStage = 4 ///< number of partitions to issue +> +class ReduceSplitK { +public: + + using Shape = Shape_; + using ReductionOp = ReductionOp_; + using OutputOp = OutputOp_; + static int const kElementsPerAccess = OutputOp::kCount; + static int const kPartitionsPerStage = PartitionsPerStage; + + using ElementWorkspace = typename ReductionOp::Element; + using ElementAccumulator = typename ReductionOp::ElementAccumulator; + using ElementOutput = typename OutputOp::ElementOutput; + + using WorkspaceTensorRef = TensorRef; + using OutputTensorRef = TensorRef; + using StrideIndex = typename WorkspaceTensorRef::Layout::Stride::Index; + + using FragmentWorkspace = AlignedArray; + using FragmentAccumulator = Array; + using FragmentOutput = AlignedArray; + + // + // Types + // + + /// Params structure + struct Params { + + MatrixCoord problem_size; + int partitions; + size_t partition_stride; + WorkspaceTensorRef workspace; + OutputTensorRef destination; + OutputTensorRef source; + typename OutputOp::Params output; + typename ReductionOp::Params reduction; + + // + // Methods + // + + 
CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params( + MatrixCoord problem_size_, + int partitions_, + size_t partition_stride_, + WorkspaceTensorRef workspace_, + OutputTensorRef destination_, + OutputTensorRef source_, + typename OutputOp::Params output_ = typename OutputOp::Params(), + typename ReductionOp::Params reduction_ = typename ReductionOp::Params() + ): + problem_size(problem_size_), + partitions(partitions_), + partition_stride(sizeof(FragmentWorkspace) * partition_stride_ / kElementsPerAccess), + workspace(workspace_), + destination(destination_), + source(source_), + output(output_), + reduction(reduction_) { + + } + }; + + struct SharedStorage { }; + + +public: + + /// Computes the grid size given a chosen threadblock shape + CUTLASS_HOST_DEVICE + static dim3 grid_shape( + cutlass::MatrixCoord problem_size) { + + return dim3( + (problem_size.row() + Shape::kRow - 1) / Shape::kRow, + (problem_size.column() + Shape::kColumn - 1) / Shape::kColumn); + } + + /// Determines the threadblock shape + CUTLASS_HOST_DEVICE + static dim3 block_shape() { + return dim3(Shape::kColumn / kElementsPerAccess, Shape::kRow); + } + + /// Perform a reduction + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &storage) { + + // Determine CTA position + MatrixCoord thread_offset( + MatrixCoord::Index(int(blockIdx.x) * Shape::kRow + threadIdx.y), + MatrixCoord::Index(int(blockIdx.y) * Shape::kColumn + threadIdx.x * kElementsPerAccess) + ); + + // One guard conditional + if (!(thread_offset.row() < params.problem_size.row() && + thread_offset.column() < params.problem_size.column())) { + + return; + } + + + ReductionOp reduction_op(params.reduction); + + FragmentAccumulator accumulator; + + accumulator.clear(); + + // + // Load the first slice + // + + char const *workspace_ptr = + reinterpret_cast( + params.workspace.data() + params.workspace.offset(thread_offset)); + + FragmentWorkspace workspace_frag[kPartitionsPerStage]; + + // + // 
Construct the output operator + // + + OutputOp output_op(params.output); + + // + // Load and accumulate with a simple batched loading sequence. + // + + CUTLASS_PRAGMA_NO_UNROLL + for (int k = 0; k < params.partitions; k += kPartitionsPerStage) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPartitionsPerStage; ++i) { + if (k + i < params.partitions) { + workspace_frag[i] = *reinterpret_cast(workspace_ptr); + workspace_ptr += params.partition_stride; + } + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPartitionsPerStage; ++i) { + if (k + i < params.partitions) { + accumulator = reduction_op(accumulator, workspace_frag[i]); + } + } + } + + // + // Conditionally load the source + // + + FragmentOutput source_frag; + + source_frag.clear(); + + FragmentOutput const *source_ptr = reinterpret_cast( + params.source.data() + params.source.offset(thread_offset)); + + if (output_op.is_source_needed()) { + reinterpret_cast(source_frag) = *source_ptr; + } + + // + // Compute the output + // + + typename OutputOp::FragmentOutput output_frag = output_op(accumulator, source_frag); + + // + // Store + // + + FragmentOutput *dest_ptr = reinterpret_cast( + params.destination.data() + params.destination.offset(thread_offset)); + + *dest_ptr = reinterpret_cast(output_frag); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace reduction +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h new file mode 100644 index 0000000000000000000000000000000000000000..914bbddda9227d1f1772d8e8171b06280b7a5f61 --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_contiguous.h @@ -0,0 +1,606 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Kernel performing a reduction over one or more ranks of an affine tensor +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/reduction/thread/reduction_operators.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (i.e. number of outer ranks) + typename ElementOutput, ///< Data type of output tensor + typename ElementSource, ///< Data type of source tensor + typename ReductionOp, ///< Reduction operator + int VectorLength = 1, ///< Vector length for memory + typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +struct TensorReductionAffineContiguousParams { + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + + Coord extent; /// Extent of source tensor + FastDivmodU64 divmod[kRank - 1]; /// FastDivmod by each strided rank + int64_t dst_stride[kReducedRank]; /// stride (units of bytes) - I, J + int64_t src_stride[kRank - 1]; /// stride (units of bytes) - I, J, K + int64_t workspace_stride; /// stride (units of 
bytes) between workspace + int workspace_count; /// number of workspaces + + uint64_t inner_count; /// Number of elements in reduced index space + uint64_t outer_count; /// Number of elements in outer index space + + ElementOutput * destination; /// Pointer to output tensor of rank kReducedRank + ElementSource const * source; /// Pointer to source pointer of rank kRank + ReductionOp reduction_op; /// Reduction operator + ElementCompute reduction_identity; /// Identity element used by reduction operator + ElementCompute *device_workspace; /// Pointer to device workspace for inter-CTA reductions + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + TensorReductionAffineContiguousParams() { + + } + + /// Ctor + TensorReductionAffineContiguousParams( + Coord extent_, ///< Extent of source tensor + ElementOutput * dst_ptr_, ///< Output tensor data + int64_t dst_stride_[], ///< Stride (units of elements) + ElementSource const * src_ptr_, ///< Source tensor data + int64_t src_stride_[], ///< Stride (units of elements) + ElementCompute *device_workspace_, ///< Pointer to device workspace for inter-CTA reductions + int64_t workspace_stride_, ///< Stride between workspaces + int workspace_count_, ///< Number of workspaces + ReductionOp reduction_op_, ///< Reduction operator + ElementCompute reduction_identity_ = ElementCompute() ///< Identity element used by reduction operator + ): + extent(extent_), + inner_count(1), + outer_count(1), + destination(dst_ptr_), + source(src_ptr_), + device_workspace(device_workspace_), + workspace_stride(workspace_stride_), + workspace_count(workspace_count_), + reduction_op(reduction_op_), + reduction_identity(reduction_identity_) { + + // Initialize divisors for fast div-mod + for (int p = 1; p < kRank; ++p) { + divmod[p - 1] = FastDivmodU64(uint64_t(extent[p])); + } + + int input_size_bits = sizeof_bits::value; + int output_size_bits = sizeof_bits::value; + + // Compute strides in units of bytes + for (int p = 0; p < kReducedRank; 
++p) { + dst_stride[p] = dst_stride_[p] * output_size_bits / 8; + } + + for (int p = 0; p < kRank - 1; ++p) { + src_stride[p] = src_stride_[p] * input_size_bits / 8; + } + + // Compute number of elements in strided ranks + for (int p = 0; p < kReducedRank; ++p) { + outer_count *= uint64_t(extent[p]); + } + + for (int p = 0; p < kInnerRank; ++p) { + inner_count *= uint64_t(extent[kRank - 1 - p]); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to reduce a tensor with affine layout over a set of ranks *INCLUDING* the contiguous +/// rank. This leads to favorable vectorized memory accesses over the contiguous rank. +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) + typename ElementOutput, ///< Data type of output tensor + typename ElementSource, ///< Data type of source tensor + typename ReductionOp, ///< Reduction operator + int VectorLength = 1, ///< Vector length for memory + typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +class TensorReductionAffineContiguous { +public: + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + using ComputeFragment = Array; + using SourceFragment = AlignedArray; + using OutputFragment = AlignedArray; + + /// Shared memory allocation used for reduction within the CTA + struct SharedStorage { + Array workspace; + }; + + /// Parameters structure + using Params = TensorReductionAffineContiguousParams< + Rank, + ReducedRank, + ElementOutput, + ElementSource, + ReductionOp, 
+ VectorLength, + ElementCompute, + Threads, + BatchSize + >; + +private: + + /// Computes the coordinate and offset of a given linear index + CUTLASS_DEVICE + void compute_inner_coord_and_offset_( + Params const ¶ms, + Coord & coord, + int64_t &src_offset, + uint64_t linear_idx) const { + + // Decompose into a coordinate of rank + coord = CoordinateDecomposition(linear_idx, ¶ms.divmod[kRank - kInnerRank]); + + // Compute an offset using the souce stride + src_offset = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kInnerRank - 1; ++i) { + src_offset += coord[i] * params.src_stride[kReducedRank + i]; + } + src_offset += coord[kInnerRank - 1] * sizeof_bits::value / 8; + } + + /// Computes the coordinate and offset of a given linear index + CUTLASS_DEVICE + void compute_outer_coord_and_offset_( + Params const ¶ms, + Coord & coord, + int64_t &dst_offset, + int64_t &src_offset, + uint64_t linear_idx) const { + + // Decompose into coordinate of rank + coord = CoordinateDecomposition(linear_idx, params.divmod); + + // Compute offsets using destination and source strides + dst_offset = 0; + src_offset = 0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kReducedRank; ++i) { + dst_offset += params.dst_stride[i] * coord[i]; + src_offset += params.src_stride[i] * coord[i]; + } + } + + /// Reduces over the reduction indices yielding a single element + CUTLASS_DEVICE + ElementCompute reduce_indices_( + Params const ¶ms, + ElementCompute *threadblock_workspace, + char const *src_byte_ptr, + int coord_c) { + + NumericArrayConverter convert_source; + ReductionOp reduction_op(params.reduction_op); + + // + // Early exit or initialize to identity element + // + if (!params.inner_count) { + return params.reduction_identity; + } + + ComputeFragment accumulator; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < int(accumulator.size()); ++i) { + accumulator[i] = params.reduction_identity; + } + + // Compute the coordinate of the first access + int64_t src_byte_offset = 0; + Coord 
coord; + + uint64_t linear_idx = (threadIdx.x + blockDim.x * threadIdx.z + blockDim.x * blockIdx.z * blockDim.z) * kVectorLength; + compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); + + // Load the first vector + SourceFragment source_fragment[kBatchSize]; + + bool not_done = true; + + // Iterate over vectors in a linearized reduction index space + while (not_done) { + + bool guards[kBatchSize]; + + // Issue a batch of loads + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBatchSize; ++b) { + + if (linear_idx < params.inner_count) { + source_fragment[b] = *reinterpret_cast(src_byte_ptr + src_byte_offset); + guards[b] = true; + } + else { + guards[b] = false; + not_done = false; + } + + linear_idx += (blockDim.z * gridDim.z * blockDim.x) * kVectorLength; + compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); + } + + // Perform a batch of reduction operations + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBatchSize; ++b) { + if (guards[b]) { + auto cvt = convert_source(source_fragment[b]); + + accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator( + reduction_op, + accumulator, + cvt); + } + } + }; + + // + // Reduction of vectors to scalar + // + + ElementCompute reduced_accumulator = accumulator[0]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kVectorLength; ++i) { + reduced_accumulator = reduction_op(reduced_accumulator, accumulator[i]); + } + + // + // Reduction within CTA across threadIdx.xz => threadIdx{.x = 0, .z = 0} + // + // This re-arranges data so threadIdx.y is effectively a row index and threadIdx.xz is a column + // + + int thread_count = blockDim.x * blockDim.z; + int thread_j = threadIdx.x + blockDim.x * threadIdx.z; + int thread_i = threadIdx.y; + + ElementCompute *frag_ptr = reinterpret_cast(threadblock_workspace) + thread_i * thread_count; + + frag_ptr[thread_j] = reduced_accumulator; + + // + // Reduce + // + CUTLASS_PRAGMA_NO_UNROLL + while (thread_count > 1) { + thread_count /= 
2; + + __syncthreads(); + + if (thread_j < thread_count) { + ElementCompute other = frag_ptr[thread_j + thread_count]; + + reduced_accumulator = reduction_op(reduced_accumulator, other); + + frag_ptr[thread_j] = reduced_accumulator; + } + + __syncthreads(); + } + + + return reduced_accumulator; + } + +public: + + /// Perform a reduction + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength; + + char const * src_byte_ptr = reinterpret_cast(params.source); + char * dst_byte_ptr = nullptr; + + // If performing a reduction across CTAs, redirect output to device workspace + if (gridDim.z == 1) { + dst_byte_ptr = reinterpret_cast(params.destination); + } + else { + dst_byte_ptr = reinterpret_cast(params.device_workspace); + } + + uint64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y; + + // Use modulo division to compute location + Coord outer_coord; + int64_t dst_byte_offset; + int64_t src_byte_offset; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + src_byte_offset, + idx_linear); + + if (gridDim.z == 1) { + + /// Complete the reduction with no workspace + while (idx_linear < params.outer_count) { + + ElementCompute result = reduce_indices_( + params, + shared_storage.workspace.data(), + src_byte_ptr + src_byte_offset, + coord_c); + + // Store the result after possible final reduction within the CTA + if (threadIdx.z == 0 && threadIdx.x == 0) { + + // Convert to output type and store + NumericConverter convert_output; + ElementOutput cvt = convert_output(result); + + *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = cvt; + } + + __syncthreads(); + + // Update indices and pointers + idx_linear += gridDim.y * blockDim.y; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + src_byte_offset, + idx_linear); + + } // while + } + else { + + /// Complete the reduction with workspace + while (idx_linear 
< params.outer_count) { + + ElementCompute result = reduce_indices_( + params, + shared_storage.workspace.data(), + src_byte_ptr + src_byte_offset, + coord_c); + + int64_t byte_offset = + blockIdx.z * params.workspace_stride + idx_linear * sizeof_bits::value / 8; + + // Store the result for final reduction + if (threadIdx.z == 0 && threadIdx.x == 0) { + *reinterpret_cast(dst_byte_ptr + byte_offset) = result; + } + + __syncthreads(); + + // Update indices and pointers + idx_linear += gridDim.y * blockDim.y; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + src_byte_offset, + idx_linear); + } // while + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to perform final reduction +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) + typename ElementOutput, ///< Data type of output tensor + typename ElementSource, ///< Data type of source tensor + typename ReductionOp, ///< Reduction operator + int VectorLength = 1, ///< Vector length for memory + typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +class TensorReductionAffineContiguousFinal { +public: + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + + /// Shared memory + struct SharedStorage { }; + + /// Parameters structure + using Params = TensorReductionAffineContiguousParams< + Rank, + ReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + VectorLength, + ElementCompute, + Threads, + BatchSize + >; + +private: + + 
/// Computes the coordinate and offset of a given linear index + CUTLASS_DEVICE + void compute_outer_coord_and_offset_( + Params const ¶ms, + Coord & coord, + int64_t &dst_offset, + uint64_t linear_idx) const { + + // Decompose into coordinate of rank + coord = CoordinateDecomposition(linear_idx, params.divmod); + + // Compute offsets using destination and source strides + dst_offset = 0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kReducedRank; ++i) { + dst_offset += params.dst_stride[i] * coord[i]; + } + } + + /// Reduces over the reduction indices + CUTLASS_DEVICE + ElementCompute reduce_indices_( + Params const ¶ms, + ElementCompute const *device_workspace) { + + ReductionOp reduction_op(params.reduction_op); + char const *src_byte_ptr = reinterpret_cast(device_workspace); + + // Accumulated output + ElementCompute accumulator = params.reduction_identity; + + for (int iter = 0; iter < params.workspace_count; ++iter) { + ElementCompute workspace_item = *reinterpret_cast(src_byte_ptr); + + accumulator = reduction_op(accumulator, workspace_item); + + src_byte_ptr += params.workspace_stride; + } + + return accumulator; + } + +public: + + // + // Methods + // + + /// Perform a reduction + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + uint64_t idx_linear = blockIdx.x * blockDim.x + threadIdx.x; + + char * dst_byte_ptr = reinterpret_cast(params.destination); + + // Use modulo division to compute location + Coord outer_coord; + int64_t dst_byte_offset; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + idx_linear); + + /// Complete the reduction + while (idx_linear < params.outer_count) { + + ElementCompute result = reduce_indices_(params, params.device_workspace + idx_linear); + + // Convert to output type and store + NumericConverter convert_output; + + *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = convert_output(result); + + // Update indices and pointers + idx_linear += gridDim.x * 
blockDim.x; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + idx_linear); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace reduction +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h new file mode 100644 index 0000000000000000000000000000000000000000..0538184f3886b53207cc28a46a9fb8b04d3e8c5e --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/kernel/tensor_reduce_affine_strided.h @@ -0,0 +1,641 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Kernel performing a reduction over one or more ranks of an affine tensor +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/reduction/thread/reduction_operators.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +/// Parameters structure +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. 
NC => 2) + typename ElementOutput, ///< Data type of output tensor + typename ElementSource, ///< Data type of source tensor + typename ReductionOp, ///< Reduction operator + int VectorLength = 1, ///< Vector length for memory + typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +struct TensorReductionAffineStridedParams { + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + + Coord extent; /// Extent of source tensor + FastDivmodU64 divmod[kRank - 1]; /// FastDivmod by each strided rank + int64_t dst_stride[kReducedRank - 1]; /// stride (units of bytes) - I, J + int64_t src_stride[kRank - 1]; /// stride (units of bytes) - I, J, K + int64_t workspace_stride; /// stride (units of bytes) between workspace + int64_t workspace_outer_stride; /// stride (units of bytes) between 'rows' of the workspace + int workspace_count; /// number of workspaces + + uint64_t inner_count; /// Number of elements in reduced index space + uint64_t outer_count; /// Number of elements in outer index space + + ElementOutput * destination; /// Pointer to output tensor of rank kReducedRank + ElementSource const * source; /// Pointer to source pointer of rank kRank + ReductionOp reduction_op; /// Reduction operator + ElementCompute reduction_identity; /// Identity element for reduction operator + ElementCompute *device_workspace; /// Pointer to device workspace for inter-CTA reductions + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + TensorReductionAffineStridedParams() { + + } + + /// Ctor + TensorReductionAffineStridedParams( + Coord extent_, ///< Extent of source tensor + ElementOutput * 
dst_ptr_, ///< Output tensor data + int64_t dst_stride_[], ///< Stride (units of elements) + ElementSource const * src_ptr_, ///< Source tensor data + int64_t src_stride_[], ///< Stride (units of elements) + ElementCompute *device_workspace_, ///< Pointer to device workspace for inter-CTA reductions + int64_t workspace_stride_, ///< Stride between workspaces + int workspace_count_, ///< Number of workspaces + ReductionOp reduction_op_, ///< Reduction operator + ElementCompute reduction_identity_ = ElementCompute() ///< Identity element for reduction operator + ): + extent(extent_), + inner_count(1), + outer_count(1), + destination(dst_ptr_), + source(src_ptr_), + device_workspace(device_workspace_), + workspace_outer_stride(0), + workspace_stride(workspace_stride_), + workspace_count(workspace_count_), + reduction_op(reduction_op_), + reduction_identity(reduction_identity_) { + + // Initialize divisors for fast div-mod + for (int p = 1; p < kRank; ++p) { + divmod[p - 1] = FastDivmodU64(uint64_t(extent[p])); + } + + int input_size_bits = sizeof_bits::value; + int output_size_bits = sizeof_bits::value; + + workspace_outer_stride = workspace_stride * workspace_count; + + // Compute strides in units of bytes + for (int p = 0; p < kReducedRank - 1; ++p) { + dst_stride[p] = dst_stride_[p] * output_size_bits / 8; + } + + for (int p = 0; p < kRank - 1; ++p) { + src_stride[p] = src_stride_[p] * input_size_bits / 8; + } + + // Compute number of elements in strided ranks + for (int p = 0; p < kReducedRank - 1; ++p) { + outer_count *= uint64_t(extent[p]); + } + + for (int p = 0; p < kInnerRank; ++p) { + inner_count *= uint64_t(extent[kReducedRank + p - 1]); + } + } +}; + +/// Kernel to reduce a tensor with affine layout over a set of ranks *EXCLUDING* the contiguous +/// rank. This leads to favorable vectorized memory accesses over the contiguous rank. +template < + int Rank, ///< Rank of source tensor (e.g. 
NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. NC => 2) + typename ElementOutput, ///< Data type of output tensor + typename ElementSource, ///< Data type of source tensor + typename ReductionOp, ///< Reduction operator + int VectorLength = 1, ///< Vector length for memory + typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +class TensorReductionAffineStrided { +public: + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + using ComputeFragment = Array; + using SourceFragment = AlignedArray; + using OutputFragment = AlignedArray; + + /// Shared memory allocation used for reduction within the CTA + struct SharedStorage { + Array workspace; + }; + + /// Parameters structure + using Params = TensorReductionAffineStridedParams< + Rank, + ReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + VectorLength, + ElementCompute, + Threads, + BatchSize + >; + +private: + + /// Computes the coordinate and offset of a given linear index + CUTLASS_DEVICE + void compute_inner_coord_and_offset_( + Params const ¶ms, + Coord & coord, + int64_t &src_offset, + uint64_t linear_idx) const { + + // Decompose into coordinate + coord = CoordinateDecomposition(linear_idx, ¶ms.divmod[kReducedRank - 1]); + + // Compute linear offset + src_offset = 0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kInnerRank; ++i) { + src_offset += params.src_stride[kReducedRank + i - 1] * coord[i]; + } + } + + /// Computes the coordinate and offset of a given linear index + CUTLASS_DEVICE + void compute_outer_coord_and_offset_( + Params const ¶ms, + Coord & coord, + int64_t 
&dst_offset, + int64_t &src_offset, + uint64_t linear_idx) const { + + // Decompose linear coordinate + coord = CoordinateDecomposition(linear_idx, params.divmod); + + // Compute offset into tensors + dst_offset = 0; + src_offset = 0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kReducedRank - 1; ++i) { + dst_offset += params.dst_stride[i] * coord[i]; + src_offset += params.src_stride[i] * coord[i]; + } + } + + /// Reduces over the reduction indices + CUTLASS_DEVICE + ComputeFragment reduce_indices_( + Params const ¶ms, + ElementCompute *threadblock_workspace, + char const *src_byte_ptr) { + + NumericArrayConverter convert_source; + ReductionOp reduction_op(params.reduction_op); + + // Accumulated output + ComputeFragment identity_frag; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < int(identity_frag.size()); ++i) { + identity_frag[i] = params.reduction_identity; + } + + if (!params.inner_count) { + return identity_frag; + } + + ComputeFragment accumulator = identity_frag; + + // Compute the coordinate of the first access + int64_t src_byte_offset = 0; + Coord coord; + + uint64_t linear_idx = threadIdx.z + blockIdx.z * blockDim.z; + compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); + + // Load the first vector + SourceFragment source_fragment[kBatchSize]; + + bool not_done = true; + + // Iterate over vectors in a linearized reduction index space + while (not_done) { + + bool guards[kBatchSize]; + + // Issue a batch of loads + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBatchSize; ++b) { + + if (linear_idx < params.inner_count) { + source_fragment[b] = *reinterpret_cast(src_byte_ptr + src_byte_offset); + guards[b] = true; + } + else { + guards[b] = false; + not_done = false; + } + + linear_idx += blockDim.z * gridDim.z; + compute_inner_coord_and_offset_(params, coord, src_byte_offset, linear_idx); + } + + // Perform a batch of reduction operations + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBatchSize; ++b) { + if (guards[b]) { + + 
auto cvt = convert_source(source_fragment[b]); + + accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator( + reduction_op, + accumulator, + cvt); + } + } + }; + + // Optional reduction within a CTA + if (blockDim.z > 1) { + + // Linearized thread ID + int thread_idx = threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z); + + // all threads store to workspace + ComputeFragment *frag_ptr = reinterpret_cast(threadblock_workspace); + + frag_ptr[thread_idx] = accumulator; + + __syncthreads(); + + if (threadIdx.z == 0) { + // Load all additional block indices + for (int z = 1; z < blockDim.z; ++z) { + ComputeFragment frag = frag_ptr[thread_idx + z * blockDim.x * blockDim.y]; + + accumulator = cutlass::reduction::thread::detail::ApplyArrayOperator( + reduction_op, + accumulator, + frag); + } + } + + __syncthreads(); + } + + return accumulator; + } + +public: + + /// Perform a reduction + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength; + + char const * src_byte_ptr = reinterpret_cast(params.source + coord_c); + char * dst_byte_ptr = nullptr; + + // If performing a reduction across CTAs, redirect output to device workspace + if (gridDim.z == 1) { + dst_byte_ptr = reinterpret_cast(params.destination + coord_c); + } + else { + dst_byte_ptr = reinterpret_cast(params.device_workspace + coord_c); + } + + // If the C index is out of bounds, exit + if (coord_c >= params.extent[kRank - 1]) { + return; + } + + int64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y; + + // Use modulo division to compute location + Coord outer_coord; + int64_t dst_byte_offset; + int64_t src_byte_offset; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + src_byte_offset, + idx_linear); + + if (gridDim.z == 1) { + + /// Complete the reduction with no workspace + while (idx_linear < params.outer_count) { + + ComputeFragment result; 
+ + result = reduce_indices_( + params, + shared_storage.workspace.data(), + src_byte_ptr + src_byte_offset); + + // Store the result after possible final reduction within the CTA + if (threadIdx.z == 0) { + + // Convert to output type and store + NumericArrayConverter convert_output; + auto cvt = convert_output(result); + + *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = + reinterpret_cast(cvt); + } + + // Update indices and pointers + idx_linear += gridDim.y * blockDim.y; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + src_byte_offset, + idx_linear); + + } // while + } + else { + + /// Complete the reduction with a device workspace + while (idx_linear < params.outer_count) { + + ComputeFragment result; + + result = reduce_indices_( + params, + shared_storage.workspace.data(), + src_byte_ptr + src_byte_offset); + + // Store the result after possible final reduction within the CTA + if (threadIdx.z == 0) { + + int64_t byte_offset = + blockIdx.z * params.workspace_stride + idx_linear * params.workspace_outer_stride; + + // No conversion - store in compute type + *reinterpret_cast(dst_byte_ptr + byte_offset) = + reinterpret_cast(result); + } + + // Update indices and pointers + idx_linear += gridDim.y * blockDim.y; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + src_byte_offset, + idx_linear); + + } // while (outer index) + } // if () + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to perform final reduction +template < + int Rank, ///< Rank of source tensor (e.g. NDHWC => 5) + int ReducedRank, ///< Rank of reduced tensor (includes contiguous, e.g. 
NC => 2) + typename ElementOutput, ///< Data type of output tensor + typename ElementSource, ///< Data type of source tensor + typename ReductionOp, ///< Reduction operator + int VectorLength = 1, ///< Vector length for memory + typename ElementCompute = ElementOutput, ///< Internal compute type - input type of reduction operation + int Threads = 256, ///< Number of participating threads + int BatchSize = 4 ///< Number of elements to load per batch +> +class TensorReductionAffineStridedFinal { +public: + + static int const kRank = Rank; + static int const kReducedRank = ReducedRank; + static int const kVectorLength = VectorLength; + static int const kInnerRank = kRank - kReducedRank; + static int const kThreads = Threads; + static int const kBatchSize = BatchSize; + using ComputeFragment = Array; + using SourceFragment = AlignedArray; + using OutputFragment = AlignedArray; + + /// Shared memory + struct SharedStorage { }; + + /// Parameters structure + using Params = TensorReductionAffineStridedParams< + Rank, + ReducedRank, + ElementOutput, + ElementSource, + ReductionOp, + VectorLength, + ElementCompute, + Threads, + BatchSize + >; + +private: + + /// Computes the coordinate and offset of a given linear index + CUTLASS_DEVICE + void compute_outer_coord_and_offset_( + Params const ¶ms, + Coord & coord, + int64_t &dst_offset, + uint64_t linear_idx) const { + + // Decompose linear index + coord = CoordinateDecomposition(linear_idx, params.divmod); + + // Compute tensor offset + dst_offset = 0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kReducedRank - 1; ++i) { + dst_offset += params.dst_stride[i] * coord[i]; + } + } + + /// Reduces over the reduction indices + CUTLASS_DEVICE + ComputeFragment reduce_indices_( + Params const ¶ms, + char *src_byte_ptr) { + + ReductionOp reduction_op(params.reduction_op); + + // Accumulated output + ComputeFragment identity_frag; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < int(identity_frag.size()); ++i) { + identity_frag[i] = 
params.reduction_identity; + } + + ComputeFragment accumulator = identity_frag; + ComputeFragment workspace_fragments[kBatchSize]; + + // Partially unrolled loop + for (int idx = 0; idx < params.workspace_count; idx += kBatchSize) { + + // Issue a batch of loads + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBatchSize; ++b) { + if (idx + b < params.workspace_count) { + workspace_fragments[b] = + *reinterpret_cast(src_byte_ptr); + } + else { + workspace_fragments[b] = identity_frag; + } + src_byte_ptr += + params.workspace_stride; + } + + // Perform a reduction + CUTLASS_PRAGMA_UNROLL + for (int b = 0; b < kBatchSize; ++b) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kVectorLength; ++i) { + accumulator[i] = reduction_op(accumulator[i], workspace_fragments[b][i]); + } + } + } + + return accumulator; + } + +public: + + // + // Methods + // + + /// Perform a reduction + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + int coord_c = (blockIdx.x * blockDim.x + threadIdx.x) * kVectorLength; + + char * src_byte_ptr = reinterpret_cast(params.device_workspace + coord_c); + char * dst_byte_ptr = reinterpret_cast(params.destination + coord_c); + + // If the C index is out of bounds, exit + if (coord_c >= params.extent[kRank - 1]) { + return; + } + + int64_t idx_linear = blockIdx.y * blockDim.y + threadIdx.y; + + // Use modulo division to compute location + Coord outer_coord; + int64_t dst_byte_offset; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + idx_linear); + + /// Complete the reduction + while (idx_linear < params.outer_count) { + + int64_t src_byte_offset = idx_linear * params.workspace_outer_stride; + + ComputeFragment result = reduce_indices_( + params, + src_byte_ptr + src_byte_offset); + + // Convert to output type and store + NumericArrayConverter convert_output; + auto cvt = convert_output(result); + + *reinterpret_cast(dst_byte_ptr + dst_byte_offset) = + reinterpret_cast(cvt); + + // 
Update indices and pointers + idx_linear += gridDim.y * blockDim.y; + + compute_outer_coord_and_offset_( + params, + outer_coord, + dst_byte_offset, + idx_linear); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace reduction +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduce.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..cc354df56a0fd83f0315370138fca729a2236d79 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduce.h @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines basic thread level reduction with specializations for Array. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/half.h" +#include "cutlass/functional.h" + +namespace cutlass { +namespace reduction { +namespace thread { + +/// Structure to compute the thread level reduction +template +struct Reduce; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial Specialization of Reduce for "plus" (a functional operator) +template +struct Reduce< plus, T > { + + CUTLASS_HOST_DEVICE + T operator()(T lhs, T const &rhs) const { + plus _op; + return _op(lhs, rhs); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization of Reduce for Array +template +struct Reduce < plus, Array> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &in) const { + + Array result; + Reduce< plus, T > scalar_reduce; + result.clear(); + + 
CUTLASS_PRAGMA_UNROLL + for (auto i = 0; i < N; ++i) { + result[0] = scalar_reduce(result[0], in[i]); + } + + return result; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specializations of Reduce for Array +template +struct Reduce < plus, Array > { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &input) { + + Array result; + + // If there is only 1 element - there is nothing to reduce + if( N ==1 ){ + + result[0] = input.front(); + + } else { + + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600) + + __half result_d; + Array const *in_ptr_half = reinterpret_cast const *>(&input); + Array const *in_ptr_half2 = reinterpret_cast const *>(&input); + __half2 const *x_in_half2 = reinterpret_cast<__half2 const *>(in_ptr_half2); + + // Set initial result = first half2, in case N==2 + __half2 tmp_result = x_in_half2[0]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < N/2; ++i) { + + tmp_result = __hadd2(x_in_half2[i], tmp_result); + + } + + result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result)); + + // One final step is needed for odd "N" (to add the (N-1)th element) + if( N%2 ){ + + __half last_element; + Array tmp_last; + Array *tmp_last_ptr = &tmp_last; + tmp_last_ptr[0] = in_ptr_half[N-1]; + last_element = reinterpret_cast<__half const &>(tmp_last); + + result_d = __hadd(result_d, last_element); + + } + + Array *result_ptr = &result; + *result_ptr = reinterpret_cast &>(result_d); + + #else + + Reduce< plus, half_t > scalar_reduce; + result.clear(); + + CUTLASS_PRAGMA_UNROLL + for (auto i = 0; i < N; ++i) { + + result[0] = scalar_reduce(result[0], input[i]); + + } + + #endif + } + + return result; + + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specializations of Reduce for AlignedArray +template +struct Reduce < plus, AlignedArray > { + + CUTLASS_HOST_DEVICE + Array operator()(AlignedArray 
const &input) { + + Array result; + + // If there is only 1 element - there is nothing to reduce + if( N ==1 ){ + + result[0] = input.front(); + + } else { + + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600) + + __half result_d; + AlignedArray const *in_ptr_half = reinterpret_cast const *>(&input); + AlignedArray const *in_ptr_half2 = reinterpret_cast const *>(&input); + __half2 const *x_in_half2 = reinterpret_cast<__half2 const *>(in_ptr_half2); + + // Set initial result = first half2, in case N==2 + __half2 tmp_result = x_in_half2[0]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < N/2; ++i) { + + tmp_result = __hadd2(x_in_half2[i], tmp_result); + + } + + result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result)); + + // One final step is needed for odd "N" (to add the (N-1)th element) + if( N%2 ){ + + __half last_element; + AlignedArray tmp_last; + AlignedArray *tmp_last_ptr = &tmp_last; + tmp_last_ptr[0] = in_ptr_half[N-1]; + last_element = reinterpret_cast<__half const &>(tmp_last); + + result_d = __hadd(result_d, last_element); + + } + + Array *result_ptr = &result; + *result_ptr = reinterpret_cast &>(result_d); + + #else + + Reduce< plus, half_t > scalar_reduce; + result.clear(); + + CUTLASS_PRAGMA_UNROLL + for (auto i = 0; i < N; ++i) { + + result[0] = scalar_reduce(result[0], input[i]); + + } + + #endif + } + + return result; + + } +}; +} +} +} diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduction_operators.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduction_operators.h new file mode 100644 index 0000000000000000000000000000000000000000..3792d332de65f19a1d30ba311d34073201176a3b --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/thread/reduction_operators.h @@ -0,0 +1,235 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Kernel performing a reduction over densely packed tensors in global memory +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mixed-precision reduction +template < + typename ElementAccumulator_, + typename Element_, + int Count = 1 +> +struct ReduceAdd { + + // + // Type definitions + // + + using ElementAccumulator = ElementAccumulator_; + using Element = Element_; + static int const kCount = Count; + + using FragmentAccumulator = cutlass::Array; + using FragmentElement = cutlass::Array; + + struct Params { }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + ReduceAdd(Params params_ = Params()): params(params_) { } + + /// Operator + CUTLASS_HOST_DEVICE + FragmentAccumulator operator()( + FragmentAccumulator accumulator, + FragmentElement element) const { + + plus op; + + NumericArrayConverter< + ElementAccumulator, + Element, + kCount, + PreferredRoundingMode::kRound> converter; + + return op(accumulator, converter(element)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Special handling for binary operators +template +struct VectorizeArrayOperation { + + using ValueType = Array; + + CUTLASS_HOST_DEVICE + ValueType operator()( + ReductionOp const &reduction_op, + ValueType const &lhs, + ValueType const &rhs) const { + + ValueType result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = 
reduction_op(lhs[i], rhs[i]); + } + + return result; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ReduceArrayOperation { + + using ArrayType = Array; + + CUTLASS_HOST_DEVICE + Element operator()( + ReductionOp const &reduction_op, + ArrayType const &array) const { + + Element item = reduction_op(array[0], array[1]); + + CUTLASS_PRAGMA_UNROLL + for (int i = 2; i < N; ++i) { + item = reduction_op(item, array[i]); + } + + return item; + } +}; + +template +struct ReduceArrayOperation, uint1b_t, N> { + + using ArrayType = Array; + + CUTLASS_HOST_DEVICE + uint1b_t operator()( + logical_and const &reduction_op, + ArrayType const &array) const { + + uint8_t const *ptr = reinterpret_cast(&array); + bool item = false; + + CUTLASS_PRAGMA_UNROLL + for (int byte = 0; byte < (N + 7) / 8; ++byte) { + uint8_t bits = ptr[byte]; + item = (item || !bits); + } + + return uint1b_t{!item}; + } +}; + +template +struct ReduceArrayOperation, uint1b_t, N> { + + using ArrayType = Array; + + CUTLASS_HOST_DEVICE + uint1b_t operator()( + logical_and const &reduction_op, + ArrayType const &array) const { + + uint8_t const *ptr = reinterpret_cast(&array); + bool item = true; + + CUTLASS_PRAGMA_UNROLL + for (int byte = 0; byte < (N + 7) / 8; ++byte) { + uint8_t bits = ptr[byte]; + item = (item || bits); + } + + return uint1b_t{item}; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper function to infer template argument types +template +CUTLASS_HOST_DEVICE +Array ApplyArrayOperator( + ReductionOp const &reduction_op, + Array const &lhs, + Array const &rhs) { + + VectorizeArrayOperation vectorize_op; + + return vectorize_op(reduction_op, lhs, rhs); +} + +/// Helper to reduce an array +template +Element ReduceArray(ReductionOp const &reduction_op, Array const &array) { + ReduceArrayOperation reduce_array_op; + + return 
reduce_array_op(reduction_op, array); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace reduction +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/threadblock_swizzle.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/threadblock_swizzle.h new file mode 100644 index 0000000000000000000000000000000000000000..bbabaed2736cac7043671f10e9813a9a48b1916c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/reduction/threadblock_swizzle.h @@ -0,0 +1,67 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +**************************************************************************************************/ +/*! \file +\brief Defies functors for mapping blockIdx to partitions of the batched reduction computation. +*/ +#pragma once +#include "cutlass/coord.h" + +namespace cutlass { +namespace reduction { +struct DefaultBlockSwizzle { + /// Ctor + CUTLASS_HOST_DEVICE DefaultBlockSwizzle() {} + + /// Swizzle the block index. 
+ CUTLASS_DEVICE dim3 swizzle() { return blockIdx; } + + /// + CUTLASS_HOST_DEVICE dim3 get_grid_layout(Coord<3> const &problem_size, + Coord<3> const &OutputTile) { + assert(OutputTile[0] == 1 && OutputTile[1] == 1); + assert((problem_size[0] * problem_size[1] * problem_size[2]) % OutputTile[2] == 0); + dim3 grid; + grid.x = problem_size[0] * problem_size[1] * problem_size[2] + / OutputTile[2] ; + return grid; + } + + /// + CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &SubTile) { + assert(SubTile[0] == 1 && SubTile[1] == 1); + dim3 block = swizzle(); + Coord<3> threadblock_offset = + make_Coord(0, 0, block.x * SubTile[2]); + return threadblock_offset; + } +}; +} // namespace reduction +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/relatively_equal.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/relatively_equal.h new file mode 100644 index 0000000000000000000000000000000000000000..68bdb26e38b1a54843eb4883833ad6b8708f0aff --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/relatively_equal.h @@ -0,0 +1,305 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Performs comparison between two elements with support for floating-point comparisons. 
+*/ + +#pragma once + +#include "numeric_types.h" +#include "complex.h" + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_HOST_DEVICE +bool relatively_equal(T a, T b, U epsilon, U nonzero_floor); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// This floating-point comparison function implements the method described in +// +// https://floating-point-gui.de/errors/comparison/ +// +template +CUTLASS_HOST_DEVICE +bool relatively_equal_float(T a, T b, T epsilon, T nonzero_floor) { + +#if defined(__CUDACC_RTC__) + using cuda::std::abs; +#else + using std::abs; +#endif + + T abs_A = abs(a); + T abs_B = abs(b); + T diff = abs(a - b); + T zero = T(0); + + if (a == b) { + return true; + } + else if (a == zero || b == zero || (abs_A + abs_B) < nonzero_floor) { + return diff < epsilon * nonzero_floor; + } + + return diff < epsilon * (abs_A + abs_B); +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(bool a, bool b, bool, bool) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(uint1b_t a, uint1b_t b, uint1b_t, uint1b_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(int2b_t a, int2b_t b, int2b_t, int2b_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(uint2b_t a, uint2b_t b, uint2b_t, uint2b_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(int4b_t a, int4b_t b, int4b_t, int4b_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(uint4b_t a, uint4b_t b, uint4b_t, uint4b_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(int8_t a, int8_t b, int8_t, int8_t) { + return 
(a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(uint8_t a, uint8_t b, uint8_t, uint8_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(int16_t a, int16_t b, int16_t, int16_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(uint16_t a, uint16_t b, uint16_t, uint16_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(int32_t a, int32_t b, int32_t, int32_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(uint32_t a, uint32_t b, uint32_t, uint32_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(int64_t a, int64_t b, int64_t, int64_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(uint64_t a, uint64_t b, uint64_t, uint64_t) { + return (a == b); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_e4m3_t a, float_e4m3_t b, float_e4m3_t epsilon, float_e4m3_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_e5m2_t a, float_e5m2_t b, float_e5m2_t epsilon, float_e5m2_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(half_t a, half_t b, half_t epsilon, half_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal( + bfloat16_t a, + bfloat16_t b, + bfloat16_t epsilon, + bfloat16_t nonzero_floor) { + + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal( + tfloat32_t a, + tfloat32_t b, + tfloat32_t epsilon, + tfloat32_t nonzero_floor) { + + return 
detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float a, float b, float epsilon, float nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(double a, double b, double epsilon, double nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template +CUTLASS_HOST_DEVICE +bool relatively_equal(complex a, complex b, T epsilon, T nonzero_floor) { +#if defined(__CUDACC_RTC__) + using cuda::std::abs; +#else + using std::abs; +#endif + + T abs_A = abs(a); + T abs_B = abs(b); + T diff = abs(a - b); + complex zero = complex{T{}, T{}}; + + if (a == b) { + return true; + } + else if (a == zero || b == zero || diff < nonzero_floor) { + return diff < epsilon * nonzero_floor; + } + + return diff < epsilon * (abs_A + abs_B); +} + +template +CUTLASS_HOST_DEVICE +bool relatively_equal(complex a, complex b, complex epsilon, complex nonzero_floor) { +#if defined(__CUDACC_RTC__) + using cuda::std::abs; +#else + using std::abs; +#endif + + T abs_A = abs(a); + T abs_B = abs(b); + complex diff = a - b; + T abs_diff = abs(diff); + complex zero = complex{T{}, T{}}; + + if (a == b) { + return true; + } + else if (a == zero || b == zero || abs_diff < abs(nonzero_floor)) { + return abs_diff < abs(epsilon * nonzero_floor); + } + + return abs_diff < abs(epsilon) * (abs_A + abs_B); +} + + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_e2m3_t a, float_e2m3_t b, float_e2m3_t epsilon, float_e2m3_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_e3m2_t a, float_e3m2_t b, float_e3m2_t epsilon, float_e3m2_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_e2m1_t a, 
float_e2m1_t b, float_e2m1_t epsilon, float_e2m1_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_ue8m0_t a, float_ue8m0_t b, float_ue8m0_t epsilon, float_ue8m0_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(float_ue4m3_t a, float_ue4m3_t b, float_ue4m3_t epsilon, float_ue4m3_t nonzero_floor) { + return detail::relatively_equal_float(a, b, epsilon, nonzero_floor); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/semaphore.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/semaphore.h new file mode 100644 index 0000000000000000000000000000000000000000..09a0a1a4572775bbdbdba63a160952e35fef2c20 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/semaphore.h @@ -0,0 +1,118 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implementation of a CTA-wide semaphore for inter-CTA synchronization. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// CTA-wide semaphore for inter-CTA synchronization. 
+class Semaphore { +public: + + int *lock; + bool wait_thread; + int state; + +public: + + /// Implements a semaphore to wait for a flag to reach a given value + CUTLASS_HOST_DEVICE + Semaphore(int *lock_, int thread_id): + lock(lock_), + wait_thread(thread_id < 0 || thread_id == 0), + state(-1) { + + } + + /// Permit fetching the synchronization mechanism early + CUTLASS_DEVICE + void fetch() { + if (wait_thread) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + #else + asm volatile ("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + #endif + } + } + + /// Gets the internal state + CUTLASS_DEVICE + int get_state() const { + return state; + } + + /// Waits until the semaphore is equal to the given value + CUTLASS_DEVICE + void wait(int status = 0) { + while( __syncthreads_and(state != status) ) { + fetch(); + } + + __syncthreads(); + } + + /// Updates the lock with the given result + CUTLASS_DEVICE + void release(int status = 0) { + __syncthreads(); + + if (wait_thread) { + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile ("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); + #else + asm volatile ("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); + #endif + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/subbyte_reference.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/subbyte_reference.h new file mode 100644 index 0000000000000000000000000000000000000000..6e98cdc3886b06626ea7d003122d62078f7767b9 --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/subbyte_reference.h @@ -0,0 +1,1388 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Provides a mechanism for packing and unpacking elements smaller than one byte +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/integer_subbyte.h" +#include "cutlass/fast_math.h" + +namespace cutlass { + +namespace detail { +// This is an implementation detail of cutlass::SubbyteReference and. +// cutlass::HostTensor. For a given logical element type Element, +// and its corresponding storage (physical) element type StorageUnit, +// it computes quantities that help with managing allocations. +// +// CUTLASS uses a hidden "ContainerUnitType" or StorageUnit type to support +// packed arrays of subbyte types such as int4. Element is the "logical" type +// for computations, while CUTLASS uses StorageUnit as the element type +// of a packed array of Element. If Element is not a subbyte type, +// then the corresponding StorageUnit type is just Element itself. +// +// The ContainerType is always calculated as an array StorageUnit type (the StorageUnit +// is always a byte for subbyte types), +// and its number of bits is the lcm of the subbyte type's number of bits and 8. +// Below are some examples for different subbyte types. +// +// * Subbyte Type=int2, ContainerType=StorageUnit[1] (StorageUnit=uint8_t) +// * Subbyte Type=int4, ContainerType=StorageUnit[1] (StorageUnit=uint8_t) +template +struct StorageContainerCalculator { + // kContainerTypeNumBits: The number of bits needed for ContainerType + static constexpr int kContainerTypeNumBits = (sizeof_bits::value < 8) ? cutlass::lcm_cxx11(sizeof_bits::value, sizeof_bits::value) : sizeof_bits::value; + static_assert(kContainerTypeNumBits % sizeof_bits::value == 0, "The bits of ContainerType should be divisible by the element's number of bits"); + // kContainerTypeNumLogicalElements: The number of logical Element instance(s) that can be stored per ContainerType instance + static constexpr int kContainerTypeNumLogicalElements = kContainerTypeNumBits / sizeof_bits::value; + /// 3. 
kContainerTypeNumBytes: The number of bytes per ContainerType instance + static constexpr int kContainerTypeNumBytes = kContainerTypeNumBits / 8; + /// 4. kContainerTypeNumBytes: The number of base StorageUnit in the ContainerType + static constexpr int kContainerTypeNumStorageUnit = kContainerTypeNumBits / sizeof_bits::value; + + static_assert(kContainerTypeNumBits != 0, "kContainerTypeNumBits can not be zero"); + static_assert(kContainerTypeNumLogicalElements != 0, "kContainerTypeNumLogicalElements can not be zero"); + static_assert(kContainerTypeNumBytes != 0, "kContainerTypeNumBytes can not be zero"); +}; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This class provides a mechanism for packing and unpacking elements smaller than one byte. It +/// assumes these sub-byte elements are packed in a traditional C++ numeric type. +/// +/// The intended application is to provide a mechanism to indirectly reference elements in +/// memory or Array<> objects whose addresses cannot otherwise be taken since they are smaller +/// than one byte. +/// +/// Supports basic pointer arithmetic: +/// +/// Example: +/// +/// int4b_t *ptr = ...; +/// +/// SubbyteReference ref = ptr; +/// ref += 15; +/// +/// int4b_t x = ref; // load an int4b_t +/// ref = x + 2_s4; // perform arithmetic on int4b_t and then store +/// +template < + typename Element_, /// CUTLASS numeric element type. + typename Storage_ = uint8_t, /// Underlying storage type. Must be able to hold an integer + /// number of objects of type Element. + class = void +> +class ConstSubbyteReference { +public: + + using Element = Element_; + using Storage = Storage_; + using StoragePointer = Storage const *; + + static_assert(sizeof_bits::value <= sizeof_bits::value, + "Size of Element must not be greater than Storage."); + + static_assert(!(sizeof_bits::value % sizeof_bits::value), + "Storage must be divisible by Element"); + +private: + + ///! 
Number of elements per storage vector + int const kElementsPerVector = sizeof_bits::value / sizeof_bits::value; + + ///! Bit mask + Storage const kMask = + ((sizeof_bits::value < sizeof_bits::value) ? + (Storage(1) << sizeof_bits::value) - Storage(1) : + ~Storage(0)); + +private: + + /// Pointer to array containing element + StoragePointer ptr_; + + /// Offset (in units of elements) from pointer. + /// + /// Invariant: must always be in range [0, kElementsPerVector) + int offset_; + +public: + + CUTLASS_HOST_DEVICE + ConstSubbyteReference(): ptr_(nullptr), offset_(0) { } + + /// Constructor + CUTLASS_HOST_DEVICE + ConstSubbyteReference( + Element const *ptr, /// pointer to memory + int64_t offset /// logical offset in units of Element + ): + ptr_(reinterpret_cast(ptr)), + offset_(0) { + + int64_t offset_in_vectors = offset / kElementsPerVector; + int64_t offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = int(offset_in_elements); + } + + /// Constructor + CUTLASS_HOST_DEVICE + ConstSubbyteReference( + Element *ptr = nullptr + ): ConstSubbyteReference(ptr, 0) { } + + /// Gets storage pointer + CUTLASS_HOST_DEVICE + StoragePointer storage_pointer() const { + return ptr_; + } + + /// Gets element offset within storage vector + CUTLASS_HOST_DEVICE + int element_offset() const { + return offset_; + } + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + Element get() const { + Storage item = Storage((*ptr_ >> (offset_ * sizeof_bits::value)) & kMask); + return reinterpret_cast(item); + } + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + operator Element() const { + return get(); + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator+=(int offset) { + + offset += offset_; + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; 
+ + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator+=(long long offset) { + + offset += offset_; + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator-=(int offset) { + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator-=(long long offset) { + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + return *this; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator+(int offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator+(long long offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator-(int offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Returns a reference to an 
element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator-=(long long offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Computes the difference in elements between references + CUTLASS_HOST_DEVICE + ptrdiff_t operator-(ConstSubbyteReference ref) const { + return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_); + } + + /// Explicit cast to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(get()); + } + + /// Explicit cast to signed 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator int64_t() const { + return int64_t(get()); + } + + /// Explicit cast to unsigned 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator uint64_t() const { + return uint64_t(get()); + } + + /// Explicit cast to float + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + + /// Explicit cast to double + CUTLASS_HOST_DEVICE + explicit operator double() const { + return double(get()); + } +}; + +template < + typename Element_, /// CUTLASS numeric element type. + typename Storage_ = /// Underlying storage type. Must be able to hold an integer + /// number of objects of type Element. + +#if defined(__CUDA_ARCH__) /// Default size depends on width of atomicCas() overloads. + #if (__CUDA_ARCH__ >= 700) /// + uint16_t + #else + uint32_t + #endif +#else + uint8_t +#endif + , + class = void +> +class SubbyteReference { +public: + + using Element = Element_; + using Storage = Storage_; + using StoragePointer = Storage *; + + static_assert(sizeof_bits::value <= sizeof_bits::value, + "Size of Element must not be greater than Storage."); + + static_assert(!(sizeof_bits::value % sizeof_bits::value), + "Storage must be divisible by Element"); + +private: + + ///! Number of elements per storage vector + int const kElementsPerVector = sizeof_bits::value / sizeof_bits::value; + + ///! 
Bit mask + Storage const kMask = + ((sizeof_bits::value < sizeof_bits::value) ? + (Storage(1) << sizeof_bits::value) - Storage(1) : + ~Storage(0)); + +private: + + /// Pointer to array containing element + StoragePointer ptr_; + + /// Offset (in units of elements) from pointer. + /// + /// Invariant: must always be in range [0, kElementsPerVector) + int offset_; + +public: + + CUTLASS_HOST_DEVICE + SubbyteReference(): ptr_(nullptr), offset_(0) { } + + /// Constructor + CUTLASS_HOST_DEVICE + SubbyteReference( + Element *ptr, /// pointer to memory + int64_t offset /// logical offset in units of Element + ): + ptr_(reinterpret_cast(ptr)), + offset_(0) { + + int64_t offset_in_vectors = offset / kElementsPerVector; + int64_t offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = int(offset_in_elements); + } + + /// Constructor + CUTLASS_HOST_DEVICE + SubbyteReference( + Element *ptr = nullptr + ): SubbyteReference(ptr, 0) { } + + /// Gets storage pointer + CUTLASS_HOST_DEVICE + StoragePointer storage_pointer() const { + return ptr_; + } + + /// Gets storage pointer + CUTLASS_HOST_DEVICE + Element * operator&() const { + return reinterpret_cast(ptr_); + } + + /// Gets element offset within storage vector + CUTLASS_HOST_DEVICE + int element_offset() const { + return offset_; + } + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + Element get() const { + uint8_t const* byte_ptr = reinterpret_cast(ptr_); + // Convert offset in elements to offset in bytes + constexpr int elements_per_byte = cutlass::sizeof_bits::value / cutlass::sizeof_bits::value; + byte_ptr += offset_ / elements_per_byte; + // Offset of element within a byte + int byte_offset = offset_ % elements_per_byte; + uint8_t item = uint8_t((*byte_ptr >> (byte_offset * cutlass::sizeof_bits::value)) & kMask); + return reinterpret_cast(item); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference & set(Element const &x) { + + Storage item = 
(reinterpret_cast(x) & kMask); + Storage kUpdateMask = Storage(~(kMask << (offset_ * cutlass::sizeof_bits::value))); + Storage new_bits = Storage(item << (offset_ * cutlass::sizeof_bits::value)); + +#if defined(__CUDA_ARCH__) + + // + // Homebrew read-modify-write + // + Storage original; + Storage updated; + + do { + + original = (*ptr_); + + updated = Storage((original & kUpdateMask) | new_bits); + + original = atomicCAS(ptr_, original, updated); + + } while (updated != original); + +#else + + Storage original = (*ptr_); + Storage updated = Storage((original & kUpdateMask) | new_bits); + *ptr_ = updated; + +#endif + + return *this; + } + + //// + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + operator Element() const { + return get(); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference &operator=(Element const & x) { + return set(x); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference &operator=(SubbyteReference const & x) { + return set(x.get()); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference &operator=( + ConstSubbyteReference const &x) { + return set(x.get()); + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator+=(int offset) { + + offset += offset_; + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator+=(long long offset) { + + offset += offset_; + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + 
SubbyteReference &operator-=(int offset) { + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator-=(long long offset) { + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + return *this; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator+(int offset) const { + + SubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator+(long long offset) const { + + SubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator-(int offset) const { + + SubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator-=(long long offset) const { + + SubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Computes the difference in elements between references + CUTLASS_HOST_DEVICE + ptrdiff_t operator-(SubbyteReference ref) const { + return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_); + } + + /// Explicit cast to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + 
return int(get()); + } + + /// Explicit cast to signed 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator int64_t() const { + return int64_t(get()); + } + + /// Explicit cast to unsigned 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator uint64_t() const { + return uint64_t(get()); + } + + /// Explicit cast to float + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + + /// Explicit cast to double + CUTLASS_HOST_DEVICE + explicit operator double() const { + return double(get()); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template using _war = T; +template < + typename Element_, /// CUTLASS numeric element type. + typename Storage_ /// Underlying basic storage type. +> +class SubbyteReference::value % sizeof_bits::value != 0>::type> { +public: + + using Element = Element_; + /// Note: It's possible that StorageUnit is not divisible by Element. + /// For example, an Element instance might be stored across 2 StorageUnit instances. + /// Thus, CUTLASS needs a storage vector to hold an integer number of Element instances. + + using StorageUnit = Storage_; +private: + using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator; +public: + static int const kBitsStoredVec = StorageContainerCalculator::kContainerTypeNumBits; + static int const kNumStorageUnitPerStoredVec = StorageContainerCalculator::kContainerTypeNumStorageUnit; + + using StorageVec = StorageUnit[kNumStorageUnitPerStoredVec]; + using StorageVecPointer = StorageVec *; + + using CudaAtomicType = typename platform::conditional< + sizeof_bits::value == 16, + uint32_t, + uint64_t + >::type; + + static_assert(sizeof_bits::value <= sizeof_bits::value, + "Size of Element must not be greater than StorageVec."); + + static_assert(!(sizeof_bits::value % sizeof_bits::value), + "StorageVec must be divisible by Element"); + +private: + + ///! 
Number of elements per storage vector + int const kElementsPerVector = sizeof_bits::value / sizeof_bits::value; + + ///! Bit mask for storage unit. + StorageUnit const kMask = (StorageUnit(1) << sizeof_bits::value) - StorageUnit(1); + + /// Pointer to array containing element + _war ptr_; + + /// Offset (in units of elements) from pointer. + /// + /// Invariant: must always be in range [0, kElementsPerVector) + int offset_; + + /// Element may be stored across 2 storage unit. + /// Low storage unit index in StorageVec + /// High storage unit index in StorageVec + int low_storage_unit_idx_; + int high_storage_unit_idx_; + + /// Full Mask to extract the entire element + uint64_t full_element_mask_; + + /// Mask to extract the Element from Low storage unit and High storage unit. + StorageUnit low_storage_mask_; + StorageUnit high_storage_mask_; + + /// Start bit index inside the storage unit. + int start_bit_idx_; + +private: + + CUTLASS_HOST_DEVICE + void update_element_status() { + int num_bits = offset_ * sizeof_bits::value; + + start_bit_idx_ = num_bits % sizeof_bits::value; + + low_storage_unit_idx_ = num_bits / sizeof_bits::value; + high_storage_unit_idx_ = sizeof_bits::value - (start_bit_idx_) < sizeof_bits::value + ? 
low_storage_unit_idx_ + 1 : low_storage_unit_idx_; + + full_element_mask_ = uint64_t(kMask) << start_bit_idx_; + low_storage_mask_ = StorageUnit(full_element_mask_ & ~StorageUnit(0)); + high_storage_mask_ = StorageUnit((full_element_mask_ >> sizeof_bits::value) & ~StorageUnit(0)); + } + +public: + + CUTLASS_HOST_DEVICE + SubbyteReference(): ptr_(nullptr), offset_(0) { } + + /// Constructor + CUTLASS_HOST_DEVICE + SubbyteReference( + Element *ptr, /// pointer to memory + int64_t offset /// logical offset in units of Element + ): + ptr_(reinterpret_cast(ptr)), + offset_(0) { + int64_t offset_in_vectors = offset / kElementsPerVector; + int64_t offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = int(offset_in_elements); + + update_element_status(); + } + + /// Constructor + CUTLASS_HOST_DEVICE + SubbyteReference( + Element *ptr = nullptr + ): SubbyteReference(ptr, 0) { } + + /// Gets StorageVec pointer + CUTLASS_HOST_DEVICE + StorageVecPointer storage_pointer() const { + return ptr_; + } + + /// Gets StorageVec pointer + CUTLASS_HOST_DEVICE + Element * operator&() const { + return reinterpret_cast(ptr_); + } + + /// Gets element offset within StorageVec vector + CUTLASS_HOST_DEVICE + int element_offset() const { + return offset_; + } + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + Element get() const { + StorageUnit low_bits = (*ptr_)[low_storage_unit_idx_] & low_storage_mask_; + StorageUnit high_bits = low_storage_unit_idx_ != high_storage_unit_idx_ ? 
(*ptr_)[high_storage_unit_idx_] & high_storage_mask_ : 0; + + uint64_t full_item = ((uint64_t)high_bits << sizeof_bits::value) | low_bits; + uint8_t result = uint8_t(full_item >> start_bit_idx_); + + return reinterpret_cast(result); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference & set(Element const &x) { + + uint64_t item = static_cast((reinterpret_cast(x) & kMask)) << start_bit_idx_; + + StorageUnit low_new_bits = StorageUnit(item & ~StorageUnit(0)); + StorageUnit high_new_bits = StorageUnit(item >> sizeof_bits::value); + + StorageUnit const kLowUpdateMask = StorageUnit((~full_element_mask_) & (~StorageUnit(0))); + StorageUnit const kHighUpdateMask = StorageUnit(((~full_element_mask_) >> sizeof_bits::value) & (~StorageUnit(0))); + +#if defined(__CUDA_ARCH__) + // + // Homebrew read-modify-write + // + if(high_storage_unit_idx_ != low_storage_unit_idx_){ + /// Only need update 2 storage unit at once. + /// consider misaligned address issue, we need to do atomicCAS twice + StorageUnit original_low_bits, original_high_bits, update_low_bits, update_high_bits; + do { + original_low_bits = ((*ptr_)[low_storage_unit_idx_]); + update_low_bits = (original_low_bits & kLowUpdateMask) | low_new_bits; + original_low_bits = atomicCAS(&((*ptr_)[low_storage_unit_idx_]), original_low_bits, update_low_bits); + } while (update_low_bits != original_low_bits); + do { + original_high_bits = ((*ptr_)[high_storage_unit_idx_]); + update_high_bits = (original_high_bits & kHighUpdateMask) | high_new_bits; + original_high_bits = atomicCAS(&((*ptr_)[high_storage_unit_idx_]), original_high_bits, update_high_bits); + } while (update_high_bits != original_high_bits); + } + else { + /// Only need update 1 storage unit. 
+ StorageUnit original, updated; + do { + original = ((*ptr_)[low_storage_unit_idx_]); + + updated = (original & kLowUpdateMask) | low_new_bits; + + original = atomicCAS(&((*ptr_)[low_storage_unit_idx_]), original, updated); + + } while (updated != original); + } +#else + + + StorageUnit update_low_bits = ((*ptr_)[low_storage_unit_idx_] & kLowUpdateMask) | low_new_bits; + StorageUnit update_high_bits = ((*ptr_)[high_storage_unit_idx_] & kHighUpdateMask) | high_new_bits; + + (*ptr_)[low_storage_unit_idx_] = update_low_bits; + + if(low_storage_unit_idx_ != high_storage_unit_idx_) + (*ptr_)[high_storage_unit_idx_] = update_high_bits; +#endif + + return *this; + } + + //// + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + operator Element() const { + return get(); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference &operator=(Element const & x) { + return set(x); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference &operator=(SubbyteReference const & x) { + return set(x.get()); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference &operator=( + ConstSubbyteReference const &x) { + return set(x.get()); + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator+=(int offset) { + + offset += offset_; + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + update_element_status(); + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator+=(long long offset) { + + offset += offset_; + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + update_element_status(); + + return *this; + } + + /// 
Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator-=(int offset) { + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + update_element_status(); + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator-=(long long offset) { + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + update_element_status(); + return *this; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator+(int offset) const { + + SubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator+(long long offset) const { + + SubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator-(int offset) const { + + SubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator-=(long long offset) const { + + SubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Computes the difference in elements between references + CUTLASS_HOST_DEVICE + ptrdiff_t operator-(SubbyteReference ref) const { + return (ptr_ - ref.ptr_) * 
kElementsPerVector + (offset_ - ref.offset_); + } + + /// Explicit cast to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(get()); + } + + /// Explicit cast to signed 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator int64_t() const { + return int64_t(get()); + } + + /// Explicit cast to unsigned 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator uint64_t() const { + return uint64_t(get()); + } + + /// Explicit cast to float + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + + /// Explicit cast to double + CUTLASS_HOST_DEVICE + explicit operator double() const { + return double(get()); + } +}; + +template using _war = T; +template < + typename Element_, /// CUTLASS numeric element type. + typename Storage_ /// Underlying storage type. Must be able to hold an integer +> +class ConstSubbyteReference::value % sizeof_bits::value != 0>::type> { +public: + + using Element = Element_; + ///! Note: Storage unit could not be divisibale by Element, + /// Type element may be stored across 2 storage units, so need a storage vector to hold integer + /// number of objects of type Element. + using StorageUnit = Storage_; + static int const kBitsStoredVec = cutlass::lcm_cxx11(sizeof_bits::value, sizeof_bits::value); + static int const kNumStorageUnitPerStoredVec = kBitsStoredVec / sizeof_bits::value; + + using StorageVec = StorageUnit[kNumStorageUnitPerStoredVec]; + using StorageVecPointer = StorageVec const *; + + using CudaAtomicType = typename platform::conditional< + sizeof_bits::value == 16, + uint32_t, + uint64_t + >::type; + + static_assert(sizeof_bits::value <= sizeof_bits::value, + "Size of Element must not be greater than StorageVec."); + + static_assert(!(sizeof_bits::value % sizeof_bits::value), + "StorageVec must be divisible by Element"); + +private: + + ///! Number of elements per storage vector + int const kElementsPerVector = sizeof_bits::value / sizeof_bits::value; + + ///! 
Bit mask for storage unit. + StorageUnit const kMask = (StorageUnit(1) << sizeof_bits::value) - StorageUnit(1); + + /// Pointer to array containing element + _war ptr_; + + /// Offset (in units of elements) from pointer. + /// + /// Invariant: must always be in range [0, kElementsPerVector) + int offset_; + + /// Element may be stored across 2 storage unit. + /// Low storage unit index in StorageVec + /// High storage unit index in StorageVec + int low_storage_unit_idx_; + int high_storage_unit_idx_; + + /// Full Mask to extract the entire element + uint64_t full_element_mask_; + + /// Mask to extract the Element from Low storage unit and High storage unit. + StorageUnit low_storage_mask_; + StorageUnit high_storage_mask_; + + /// Start bit index inside the storage unit. + int start_bit_idx_; + +private: + + CUTLASS_HOST_DEVICE + void update_element_status() { + int num_bits = offset_ * sizeof_bits::value; + + start_bit_idx_ = num_bits % sizeof_bits::value; + + low_storage_unit_idx_ = num_bits / sizeof_bits::value; + high_storage_unit_idx_ = sizeof_bits::value - (start_bit_idx_) < sizeof_bits::value + ? 
low_storage_unit_idx_ + 1 : low_storage_unit_idx_; + + full_element_mask_ = uint64_t(kMask) << start_bit_idx_; + low_storage_mask_ = StorageUnit(full_element_mask_ & ~StorageUnit(0)); + high_storage_mask_ = StorageUnit((full_element_mask_ >> sizeof_bits::value) & ~StorageUnit(0)); + } + +public: + + CUTLASS_HOST_DEVICE + ConstSubbyteReference(): ptr_(nullptr), offset_(0) { } + + /// Constructor + CUTLASS_HOST_DEVICE + ConstSubbyteReference( + Element const *ptr, /// pointer to memory + int64_t offset /// logical offset in units of Element + ): + ptr_(reinterpret_cast(ptr)), + offset_(0) { + + int64_t offset_in_vectors = offset / kElementsPerVector; + int64_t offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = int(offset_in_elements); + + update_element_status(); + } + + /// Constructor + CUTLASS_HOST_DEVICE + ConstSubbyteReference( + Element *ptr = nullptr + ): ConstSubbyteReference(ptr, 0) { } + + /// Gets storage pointer + CUTLASS_HOST_DEVICE + StorageVecPointer storage_pointer() const { + return ptr_; + } + + /// Gets element offset within storage vector + CUTLASS_HOST_DEVICE + int element_offset() const { + return offset_; + } + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + Element get() const { + StorageUnit low_bits = (*ptr_)[low_storage_unit_idx_] & low_storage_mask_; + StorageUnit high_bits = low_storage_unit_idx_ != high_storage_unit_idx_ ? 
(*ptr_)[high_storage_unit_idx_] & high_storage_mask_ : 0; + + uint64_t full_item = ((uint64_t)high_bits << sizeof_bits::value) | low_bits; + uint8_t result = uint8_t(full_item >> start_bit_idx_); + + return reinterpret_cast(result); + } + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + operator Element() const { + return get(); + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator+=(int offset) { + + offset += offset_; + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + update_element_status(); + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator+=(long long offset) { + + offset += offset_; + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + update_element_status(); + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator-=(int offset) { + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + update_element_status(); + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator-=(long long offset) { + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + update_element_status(); + + return 
*this; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator+(int offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator+(long long offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator-(int offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator-=(long long offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Computes the difference in elements between references + CUTLASS_HOST_DEVICE + ptrdiff_t operator-(ConstSubbyteReference ref) const { + return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_); + } + + /// Explicit cast to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(get()); + } + + /// Explicit cast to signed 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator int64_t() const { + return int64_t(get()); + } + + /// Explicit cast to unsigned 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator uint64_t() const { + return uint64_t(get()); + } + + /// Explicit cast to float + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + + /// Explicit cast to double + CUTLASS_HOST_DEVICE + explicit operator double() const { + return double(get()); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template ::value < 8)> +struct 
ReferenceFactory; + +template +struct ReferenceFactory { + + ///! Number of elements per storage vector + static int const kElementsPerVector = 1; + + CUTLASS_HOST_DEVICE + static Element &get(Element *ptr, int64_t offset) { + return ptr[offset]; + } + + CUTLASS_HOST_DEVICE + static Element const &get(Element const *ptr, int64_t offset) { + return ptr[offset]; + } + + CUTLASS_HOST_DEVICE + static Element *add_pointer_offset(Element *ptr, int64_t offset) { + return ptr + offset; + } + + CUTLASS_HOST_DEVICE + static Element const *add_pointer_offset(Element const *ptr, int64_t offset) { + return ptr + offset; + } +}; + +template +struct ReferenceFactory { + + // + // Static methods + // + + CUTLASS_HOST_DEVICE + static SubbyteReference get(Element *ptr, int64_t offset) { + return SubbyteReference(ptr, offset); + } + + CUTLASS_HOST_DEVICE + static ConstSubbyteReference get(Element const *ptr, + int64_t offset) { + return ConstSubbyteReference(ptr, offset); + } + + /// Helper to add an offset in number of elements, assuming this offset is divisible + /// by the vector size. + CUTLASS_HOST_DEVICE + static Element *add_pointer_offset(Element *ptr, int64_t offset_in_elements) { + return &SubbyteReference(ptr, offset_in_elements); + } + + /// Helper to add an offset in number of elements, assuming this offset is divisible + /// by the vector size. 
+ CUTLASS_HOST_DEVICE + static Element const *add_pointer_offset(Element const *ptr, int64_t offset_in_elements) { + return &ConstSubbyteReference(ptr, offset_in_elements); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_coord.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_coord.h new file mode 100644 index 0000000000000000000000000000000000000000..a124d395cf2222331e0ceb160271b1621688fd6f --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_coord.h @@ -0,0 +1,326 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines a canonical coordinate for rank=4 tensors offering named indices. +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/coord.h" + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a canonical 4D coordinate used by tensor operations. +struct Tensor4DCoord : public Coord<4> { + + /// Base class + using Base = Coord<4>; + + /// Index type + using Index = typename Base::Index; + + /// LongIndex type + using LongIndex = typename Base::LongIndex; + + /// Batch dimension + static int const kN = 0; + + /// Height dimension + static int const kH = 1; + + /// Width dimension + static int const kW = 2; + + /// Channels dimension + static int const kC = 3; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Tensor4DCoord() { } + + /// Constructs from Coord<4> + CUTLASS_HOST_DEVICE + Tensor4DCoord(Coord<4> const &coord): Base(coord) { } + + /// Helper to construct from N, H, W, and C. 
+ CUTLASS_HOST_DEVICE + Tensor4DCoord(Index n, Index h, Index w, Index c): Base(make_Coord(n, h, w, c)) { } + + /// Helper to construct from N, H, W, and C, which are LongIndex type + CUTLASS_HOST_DEVICE + Tensor4DCoord(LongIndex n, LongIndex h, LongIndex w, LongIndex c) + : Base(make_Coord(Index(n), Index(h), Index(w), Index(c))) { } + + /// Returns the batch of the coordinate + CUTLASS_HOST_DEVICE + Index const & n() const { return this->at(kN); } + + /// Returns the batch of the coordinate + CUTLASS_HOST_DEVICE + Index & n() { return this->at(kN); } + + /// Returns the row of the coordinate + CUTLASS_HOST_DEVICE + Index const & h() const { return this->at(kH); } + + /// Returns the row of the coordinate + CUTLASS_HOST_DEVICE + Index & h() { return this->at(kH); } + + /// Returns the column of the coordinate + CUTLASS_HOST_DEVICE + Index const & w() const { return this->at(kW); } + + /// Returns the column of the coordinate + CUTLASS_HOST_DEVICE + Index & w() { return this->at(kW); } + + /// Returns the channel of the coordinate + CUTLASS_HOST_DEVICE + Index const & c() const { return this->at(kC); } + + /// Returns the channel of the coordinate + CUTLASS_HOST_DEVICE + Index & c() { return this->at(kC); } + + // + // Coord operators + // + + /// Element-wise addition + CUTLASS_HOST_DEVICE + Tensor4DCoord operator+(Base const& b) const { + return Tensor4DCoord(Base::operator+(b)); + } + + /// Element-wise subtraction + CUTLASS_HOST_DEVICE + Tensor4DCoord operator-(Base const& b) const { + return Tensor4DCoord(Base::operator-(b)); + } + + /// Element-wise multiplication + CUTLASS_HOST_DEVICE + Tensor4DCoord operator*(Base const& b) const { + return Tensor4DCoord(Base::operator*(b)); + } + + /// Element-wise division + CUTLASS_HOST_DEVICE + Tensor4DCoord operator/(Base const& b) const { + return Tensor4DCoord(Base::operator/(b)); + } + + /// In-place addition + CUTLASS_HOST_DEVICE + Tensor4DCoord& operator+=(Base const& b) { + Base::operator+=(b); + return *this; + 
} + + /// In-place subtraction + CUTLASS_HOST_DEVICE + Tensor4DCoord& operator-=(Base const& b) { + Base::operator-=(b); + return *this; + } + + /// In-place multiplication + CUTLASS_HOST_DEVICE + Tensor4DCoord& operator*=(Base const& b) { + Base::operator*=(b); + return *this; + } + + /// In-place division + CUTLASS_HOST_DEVICE + Tensor4DCoord& operator/=(Base const& b) { + Base::operator/=(b); + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a canonical 5D coordinate used by tensor operations. +struct Tensor5DCoord : public Coord<5> { + + /// Base class + using Base = Coord<5>; + + /// Index type + using Index = typename Base::Index; + + /// LongIndex type + using LongIndex = typename Base::LongIndex; + + /// Batch dimension + static int const kN = 0; + + /// Depth dimension + static int const kD = 1; + + /// Height dimension + static int const kH = 2; + + /// Width dimension + static int const kW = 3; + + /// Channels dimension + static int const kC = 4; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Tensor5DCoord() { } + + /// Constructs from Coord<5> + CUTLASS_HOST_DEVICE + Tensor5DCoord(Coord<5> const &coord): Base(coord) { } + + /// Helper to construct from N, D, H, W, and C. 
+ CUTLASS_HOST_DEVICE + Tensor5DCoord(Index n, Index d, Index h, Index w, Index c): Base(make_Coord(n, d, h, w, c)) { } + + /// Helper to construct from N, D, H, W, and C, which are LongIndex type + CUTLASS_HOST_DEVICE + Tensor5DCoord(LongIndex n, LongIndex d, LongIndex h, LongIndex w, LongIndex c) + : Base(make_Coord(Index(n), Index(d), Index(h), Index(w), Index(c))) { } + + /// Returns the batch of the coordinate + CUTLASS_HOST_DEVICE + Index const & n() const { return this->at(kN); } + + /// Returns the batch of the coordinate + CUTLASS_HOST_DEVICE + Index & n() { return this->at(kN); } + + /// Returns the batch of the coordinate + CUTLASS_HOST_DEVICE + Index const & d() const { return this->at(kD); } + + /// Returns the batch of the coordinate + CUTLASS_HOST_DEVICE + Index & d() { return this->at(kD); } + + /// Returns the row of the coordinate + CUTLASS_HOST_DEVICE + Index const & h() const { return this->at(kH); } + + /// Returns the row of the coordinate + CUTLASS_HOST_DEVICE + Index & h() { return this->at(kH); } + + /// Returns the column of the coordinate + CUTLASS_HOST_DEVICE + Index const & w() const { return this->at(kW); } + + /// Returns the column of the coordinate + CUTLASS_HOST_DEVICE + Index & w() { return this->at(kW); } + + /// Returns the channel of the coordinate + CUTLASS_HOST_DEVICE + Index const & c() const { return this->at(kC); } + + /// Returns the channel of the coordinate + CUTLASS_HOST_DEVICE + Index & c() { return this->at(kC); } + + // + // Coord operators + // + + /// Element-wise addition + CUTLASS_HOST_DEVICE + Tensor5DCoord operator+(Base const& b) const { + return Tensor5DCoord(Base::operator+(b)); + } + + /// Element-wise subtraction + CUTLASS_HOST_DEVICE + Tensor5DCoord operator-(Base const& b) const { + return Tensor5DCoord(Base::operator-(b)); + } + + /// Element-wise multiplication + CUTLASS_HOST_DEVICE + Tensor5DCoord operator*(Base const& b) const { + return Tensor5DCoord(Base::operator*(b)); + } + + /// Element-wise 
division + CUTLASS_HOST_DEVICE + Tensor5DCoord operator/(Base const& b) const { + return Tensor5DCoord(Base::operator/(b)); + } + + /// In-place addition + CUTLASS_HOST_DEVICE + Tensor5DCoord& operator+=(Base const& b) { + Base::operator+=(b); + return *this; + } + + /// In-place subtraction + CUTLASS_HOST_DEVICE + Tensor5DCoord& operator-=(Base const& b) { + Base::operator-=(b); + return *this; + } + + /// In-place multiplication + CUTLASS_HOST_DEVICE + Tensor5DCoord& operator*=(Base const& b) { + Base::operator*=(b); + return *this; + } + + /// In-place division + CUTLASS_HOST_DEVICE + Tensor5DCoord& operator/=(Base const& b) { + Base::operator/=(b); + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_ref.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_ref.h new file mode 100644 index 0000000000000000000000000000000000000000..fc467499996a00645b0a936efe741ece2092fb90 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_ref.h @@ -0,0 +1,419 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines a structure containing strides, bounds, and a pointer to tensor data. +*/ +#pragma once + + +#include "cutlass/cutlass.h" +#include "cutlass/coord.h" +#include "cutlass/platform/platform.h" +#include "cutlass/subbyte_reference.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Default layout function from coordinates in a tensor's index space into the n-D array held +/// in memory. +/// +/// All layout functions must define at least the members shown in IdentityTensorLayout<>. 
+template +class IdentityTensorLayout { +public: + /// Logical rank of tensor + static int const kRank = Rank; + + /// Rank of stride vector + static int const kStrideRank = Rank; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using TensorCoord = Coord; + + /// Stride vector + using Stride = Coord; + +private: + + // + // Data members + // + + /// Stride data member + Stride stride_; + +public: + + // + // Methods + // + + CUTLASS_HOST_DEVICE + IdentityTensorLayout(Stride const &stride = Stride()): stride_(stride) { } + + /// Returns the offset of a coordinate in linear memory + CUTLASS_HOST_DEVICE + LongIndex operator()(Coord const &coord) const { + return coord.dot(stride_); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &size) const { + int idx = stride_.max_dim_index(); + return stride_[idx] * size[idx]; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank + and layout within memory. A TensorRef combines a pointer and a Layout concept + + Examples: + + (These examples use helpers for matrix layouts defined in cutlass/layout/matrix.h) + + 1. Column-major matrix may be represented as a rank=2 tensor: + + TensorRef A(ptr_A, ldm); + + 2. Row-major matrix may be represented as a rank=2 tensor: + + TensorRef B(ptr_A, ldm); + + 3. An interleaved matrix may be represented as a rank=2 tensor: + + TensorRef > C; + + 4. 
A helper exists to define a TensorRef for a contiguous matrix whose layout + is not known at compile time. + + int ldm; // leading dimension + layout::Matrix kind; // Could be layout::Matrix::kRowMajor or layout::Matrix::kColumnMajor + + + TensorRef E(ptr_E, {ldm, kind}); + +*/ +template < + /// Data type of element stored within tensor (concept: NumericType) + typename Element_, + /// Defines a mapping from logical coordinate to linear memory (concept: Layout) + typename Layout_ +> +class TensorRef { + public: + /// Data type of individual access + using Element = Element_; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + /// Reference type to an element + using Reference = typename platform::conditional< + sizeof_bits::value >= 8, + Element &, + SubbyteReference + >::type; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Layout's stride vector + using Stride = typename Layout::Stride; + + /// TensorRef to constant data + using ConstTensorRef = TensorRef< + typename platform::remove_const::type const, + Layout>; + + /// TensorRef to non-constant data + using NonConstTensorRef = TensorRef< + typename platform::remove_const::type, + Layout>; + + /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a + /// scalar, but degenerate cases such as these are difficult to accommodate without + /// extensive C++ metaprogramming or support for zero-length arrays. 
+ static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); + + private: + + /// Pointer + Element* ptr_; + + /// Layout object maps logical coordinates to linear offsets + Layout layout_; + + public: + + // + // Methods + // + + /// Constructs a TensorRef with a pointer and layout object. + CUTLASS_HOST_DEVICE + TensorRef(): ptr_(nullptr) { + + } + + /// Constructs a TensorRef with a pointer and layout object. + CUTLASS_HOST_DEVICE + TensorRef( + Element *ptr, ///< pointer to start of tensor + Layout const &layout ///< layout object containing stride and mapping function + ): + ptr_(ptr), layout_(layout) { + + } + + /// Converting constructor from TensorRef to non-constant data. + template + CUTLASS_HOST_DEVICE + TensorRef( + NonConstTensorRef const &ref, ///< TensorRef to non-const data + ///SFINAE trick to avoid creating a copy-constructor when Element_ is already non-const + _Magic magic = (typename platform::enable_if< ! platform::is_same >::value, _Magic>::type)0 + ): + ptr_(ref.data()), layout_(ref.layout()) { } + + /// Returns a reference to constant-valued tensor. 
+ CUTLASS_HOST_DEVICE + ConstTensorRef const_ref() const { + return ConstTensorRef(ptr_, layout_); + } + + CUTLASS_HOST_DEVICE + NonConstTensorRef non_const_ref() const { + return NonConstTensorRef(const_cast::type *>(ptr_), layout_); + } + + /// Updates only the pointer + CUTLASS_HOST_DEVICE + void reset(Element* ptr = nullptr) { + ptr_ = ptr; + } + + /// Updates the pointer and layout object + CUTLASS_HOST_DEVICE + void reset(Element* ptr, Layout const &layout) { + ptr_ = ptr; + layout_ = layout; + } + + /// Returns true if the TensorRef is non-null + CUTLASS_HOST_DEVICE + bool good() const { + return ptr_ != nullptr; + } + + /// Returns the pointer to referenced data + CUTLASS_HOST_DEVICE + Element * data() const { return ptr_; } + + /// Returns a reference to the element at a given linear index + CUTLASS_HOST_DEVICE + Reference data(LongIndex idx) const { + return ReferenceFactory::type, + (sizeof_bits::value < 8)>::get(ptr_, idx); + } + + /// Returns the layout object + CUTLASS_HOST_DEVICE + Layout & layout() { + return layout_; + } + + /// Returns the layout object + CUTLASS_HOST_DEVICE + Layout layout() const { + return layout_; + } + + /// Returns the layout object's stride vector + CUTLASS_HOST_DEVICE + Stride stride() const { + return layout_.stride(); + } + + /// Returns the layout object's stride vector + CUTLASS_HOST_DEVICE + Stride & stride() { + return layout_.stride(); + } + + /// Returns the layout object's stride in a given physical dimension + CUTLASS_HOST_DEVICE + typename Layout::Stride::Index stride(int dim) const { + return layout_.stride().at(dim); + } + + /// Returns the layout object's stride in a given physical dimension + CUTLASS_HOST_DEVICE + typename Layout::Stride::Index & stride(int dim) { + return layout_.stride().at(dim); + } + + /// Computes the offset of an index from the origin of the tensor + CUTLASS_HOST_DEVICE + LongIndex offset(TensorCoord const& coord) const { + return layout_(coord); + } + + /// Returns a reference to the 
element at a given Coord + CUTLASS_HOST_DEVICE + Reference at(TensorCoord const& coord) const { + return data(offset(coord)); + } + + /// Returns a reference to the element at a given Coord + CUTLASS_HOST_DEVICE + Reference operator[](TensorCoord const& coord) const { + return data(offset(coord)); + } + + /// Adds an offset to each pointer + CUTLASS_HOST_DEVICE + TensorRef & add_pointer_offset(LongIndex offset_) { + ptr_ = ReferenceFactory::type, + (sizeof_bits::value < 8)>::add_pointer_offset(ptr_, offset_); + return *this; + } + + /// Adds an offset to each pointer + CUTLASS_HOST_DEVICE + TensorRef & add_coord_offset(TensorCoord const &coord) { + add_pointer_offset(offset(coord)); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRef operator+(TensorCoord const& b) const { + TensorRef result(*this); + result.add_coord_offset(b); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRef & operator+=(TensorCoord const& b) { + add_coord_offset(b); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRef operator-(TensorCoord const& b) const { + TensorRef result(*this); + result.add_pointer_offset(-offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRef & operator-=(TensorCoord const& b) { + add_pointer_offset(-offset(b)); + return *this; + } +}; + +/// Constructs a TensorRef, deducing types from arguments. +template < + typename Element, + typename Layout +> +CUTLASS_HOST_DEVICE +TensorRef make_TensorRef(Element *ptr, Layout const &layout) { + return TensorRef(ptr, layout); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations to handle degenerate and sub-byte cases. 
+// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Element, + typename Layout +> +CUTLASS_HOST_DEVICE +bool TensorRef_aligned(TensorRef const &ref, int alignment) { + + int const kStrideRank = Layout::kStrideRank; + + if (reinterpret_cast(ref.data()) % alignment) { + return false; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStrideRank; ++i) { + if (ref.stride(i) % alignment) { + return false; + } + } + + return true; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_ref_planar_complex.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_ref_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..9ba3a2308081e8c4b11d18cb8125ec7943e534f0 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_ref_planar_complex.h @@ -0,0 +1,374 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines a structure containing strides, bounds, and a pointer to tensor data. 
+*/ +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/tensor_ref.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct PlanarComplexReference { + + // + // Type definitions + // + + using Element = Element_; + using ComplexElement = complex; + + // + // Data members + // + + Element *real; + Element *imag; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + PlanarComplexReference( + Element *real_ = nullptr, + Element *imag_ = nullptr + ): + real(real_), imag(imag_) { } + + /// Loads the complex element + CUTLASS_HOST_DEVICE + operator complex() const { + return complex{*real, *imag}; + } + + /// Stores a complex element to the location pointed to by the reference + CUTLASS_HOST_DEVICE + PlanarComplexReference &operator=(complex const &rhs) { + *real = rhs.real(); + *imag = rhs.imag(); + return *this; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank + and layout within memory. 
A TensorRef combines a pointer and a Layout concept + +*/ +template < + /// Data type of element stored within tensor (concept: NumericType) + typename Element_, + /// Defines a mapping from logical coordinate to linear memory (concept: Layout) + typename Layout_ +> +class TensorRefPlanarComplex { + public: + /// Data type of individual access + using Element = Element_; + + /// Complex element type + using ComplexElement = complex; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + static_assert(sizeof_bits::value >= 8, + "Planar complex not suitable for subbyte elements at this time"); + + /// Reference type to an element + using Reference = PlanarComplexReference; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Layout's stride vector + using Stride = typename Layout::Stride; + + /// TensorRef to constant data + using ConstTensorRef = TensorRefPlanarComplex< + typename platform::remove_const::type const, + Layout>; + + /// TensorRef to non-constant data + using NonConstTensorRef = TensorRefPlanarComplex< + typename platform::remove_const::type, + Layout>; + + /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a + /// scalar, but degenerate cases such as these are difficult to accommodate without + /// extensive C++ metaprogramming or support for zero-length arrays. 
+ static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); + + private: + + /// Pointer + Element* ptr_; + + /// Layout object maps logical coordinates to linear offsets + Layout layout_; + + /// Offset to imaginary part + LongIndex imaginary_stride_; + + public: + + // + // Methods + // + + /// Constructs a TensorRef with a pointer and layout object. + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex( + Element *ptr = nullptr, ///< pointer to start of tensor + Layout const &layout = Layout(), ///< layout object containing stride and mapping function + LongIndex imaginary_stride = 0 + ): + ptr_(ptr), layout_(layout), imaginary_stride_(imaginary_stride) { + + } + + /// Converting constructor from TensorRef to non-constant data. + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex( + NonConstTensorRef const &ref ///< TensorRef to non-const data + ): + ptr_(ref.data()), layout_(ref.layout()), imaginary_stride_(ref.imaginary_stride_) { } + + /// Returns a reference to constant-valued tensor. + CUTLASS_HOST_DEVICE + ConstTensorRef const_ref() const { + return ConstTensorRef(ptr_, layout_, imaginary_stride_); + } + + CUTLASS_HOST_DEVICE + NonConstTensorRef non_const_ref() const { + return NonConstTensorRef( + const_cast::type *>(ptr_), + layout_, + imaginary_stride_); + } + + /// Updates only the pointer + CUTLASS_HOST_DEVICE + void reset(Element* ptr = nullptr, LongIndex imaginary_stride = 0) { + ptr_ = ptr; + imaginary_stride_ = imaginary_stride; + } + + /// Updates the pointer and layout object + CUTLASS_HOST_DEVICE + void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride) { + ptr_ = ptr; + layout_ = layout; + imaginary_stride_ = imaginary_stride; + } + + /// Returns true if the TensorRef is non-null + CUTLASS_HOST_DEVICE + bool good() const { + return ptr_ != nullptr; + } + + /// Returns the pointer to referenced data + CUTLASS_HOST_DEVICE + Element * data() const { return ptr_; } + + /// Returns the pointer to referenced data + CUTLASS_HOST_DEVICE + 
Element * imaginary_data() const { return ptr_ + imaginary_stride_; } + + /// Returns a reference to the element at a given linear index + CUTLASS_HOST_DEVICE + Reference data(LongIndex idx) const { + return Reference(ptr_ + idx, ptr_ + idx + imaginary_stride_); + } + + /// Returns the layout object + CUTLASS_HOST_DEVICE + Layout & layout() { + return layout_; + } + + /// Returns the layout object + CUTLASS_HOST_DEVICE + Layout layout() const { + return layout_; + } + + /// Gets the stride to an imaginary element + LongIndex imaginary_stride() const { + return imaginary_stride_; + } + + /// Gets the stride to an imaginary element + LongIndex &imaginary_stride() { + return imaginary_stride_; + } + + /// Returns the layout object's stride vector + CUTLASS_HOST_DEVICE + Stride stride() const { + return layout_.stride(); + } + + /// Returns the layout object's stride vector + CUTLASS_HOST_DEVICE + Stride & stride() { + return layout_.stride(); + } + + /// Returns the layout object's stride in a given physical dimension + CUTLASS_HOST_DEVICE + Index stride(int dim) const { + return layout_.stride().at(dim); + } + + /// Returns the layout object's stride in a given physical dimension + CUTLASS_HOST_DEVICE + Index & stride(int dim) { + return layout_.stride().at(dim); + } + + /// Computes the offset of an index from the origin of the tensor + CUTLASS_HOST_DEVICE + LongIndex offset(TensorCoord const& coord) const { + return layout_(coord); + } + + /// Returns a reference to the element at a given Coord + CUTLASS_HOST_DEVICE + Reference at(TensorCoord const& coord) const { + return data(offset(coord)); + } + + /// Returns a reference to the element at a given Coord + CUTLASS_HOST_DEVICE + Reference operator[](TensorCoord const& coord) const { + return data(offset(coord)); + } + + /// Adds an offset to each pointer + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex & add_pointer_offset(LongIndex offset_) { + ptr_ += offset_; + return *this; + } + + /// Adds an offset to each 
pointer + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex & add_coord_offset(TensorCoord const &coord) { + add_pointer_offset(offset(coord)); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex operator+(TensorCoord const& b) const { + TensorRefPlanarComplex result(*this); + result.add_coord_offset(b); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex & operator+=(TensorCoord const& b) { + add_coord_offset(b); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex operator-(TensorCoord const& b) const { + TensorRefPlanarComplex result(*this); + result.add_pointer_offset(-offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRefPlanarComplex & operator-=(TensorCoord const& b) { + add_pointer_offset(-offset(b)); + return *this; + } + + /// TensorRef to real-valued tensor + CUTLASS_HOST_DEVICE + cutlass::TensorRef ref_real() const { + return cutlass::TensorRef(data(), layout()); + } + + /// TensorRef to real-valued tensor + CUTLASS_HOST_DEVICE + cutlass::TensorRef ref_imag() const { + return cutlass::TensorRef(imaginary_data(), layout()); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructs a TensorRef, deducing types from arguments. 
+template < + typename Element, + typename Layout +> +CUTLASS_HOST_DEVICE +TensorRefPlanarComplex make_TensorRefPlanarComplex( + Element *ptr, + Layout const &layout, + int64_t imaginary_stride) { + + return TensorRefPlanarComplex(ptr, layout, imaginary_stride); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_view.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_view.h new file mode 100644 index 0000000000000000000000000000000000000000..d669443abd8b5b246a9d2aaf2ce4dd91f782f948 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_view.h @@ -0,0 +1,297 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines a structure containing strides and a pointer to tensor data. + + TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus, + it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from + data storage and is therefore lightweight and may be embedded in larger tensor objects or + memory structures. + + See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to + linear memory. 
+*/ + +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Data type of element stored within tensor + typename Element_, + /// Maps a Coord in the logical tensor index space to the internal n-D array + typename Layout_ +> +class TensorView : public TensorRef { + public: + + /// Base tensor reference + using Base = cutlass::TensorRef; + + /// Mapping function from logical coordinate to internal n-D array + using Layout = Layout_; + + /// TensorRef pointing to constant memory + using ConstTensorRef = typename Base::ConstTensorRef; + + /// Underlying TensorRef type + using TensorRef = Base; + + /// Data type of individual access + using Element = Element_; + + /// Reference type to an element + using Reference = Element &; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Coordinate in storage n-D array + using Stride = typename Layout::Stride; + + /// TensorView pointing to constant memory + using ConstTensorView = TensorView< + typename platform::remove_const::type const, + Layout>; + + /// TensorView pointing to non-constant memory + using NonConstTensorView = TensorView< + typename platform::remove_const::type, + Layout>; + + /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a + /// scalar, but degenerate cases such as these are difficult to accommodate without + /// extensive C++ metaprogramming or support for zero-length arrays. 
+ static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); + + private: + + /// View extent + TensorCoord extent_; + + public: + + // + // Methods + // + + /// Constructs a TensorView object + CUTLASS_HOST_DEVICE + TensorView() { } + + /// Constructs a TensorView object + CUTLASS_HOST_DEVICE + TensorView( + Element *ptr, ///< pointer to start of tensor + Layout const &layout, ///< layout object containing stride and mapping function + TensorCoord const &extent ///< size of the view in logical coordinates + ): + Base(ptr, layout), extent_(extent) { + + } + + /// Constructs a TensorView object + CUTLASS_HOST_DEVICE + TensorView( + TensorRef const &ref, ///< pointer and layout object referencing a tensor + TensorCoord const &extent ///< logical size of tensor + ): + Base(ref), extent_(extent) { + + } + + /// Converting constructor from TensorRef to non-constant data. + CUTLASS_HOST_DEVICE + TensorView( + NonConstTensorView const &view ///< TensorView to non-const data + ): + Base(view), extent_(view.extent_) { } + + /// Updates the pointer and layout object + CUTLASS_HOST_DEVICE + void reset(Element* ptr, Layout const &layout, TensorCoord const &extent) { + Base::reset(ptr, layout); + this->resize(extent); + } + + /// Updates the pointer + CUTLASS_HOST_DEVICE + void reset(Element* ptr) { + Base::reset(ptr); + } + + /// Changes the size of the view without affecting pointer or layout + CUTLASS_HOST_DEVICE + void resize(TensorCoord const &extent) { + this->extent_ = extent; + } + + /// Returns the extent of the view (the size along each logical dimension). + CUTLASS_HOST_DEVICE + TensorCoord const& extent() const { return extent_; } + + /// Returns the extent along a particular logical dimension. 
+ CUTLASS_HOST_DEVICE + Index extent(int dim) const { return extent_.at(dim); } + + /// Returns the number of logical elements + CUTLASS_HOST_DEVICE + LongIndex size() const { + return extent_.product(); + } + + /// Determines whether a location is within a tensor + CUTLASS_HOST_DEVICE + bool contains(TensorCoord const& coord) const { + CUTLASS_PRAGMA_UNROLL + for (int dim = 0; dim < kRank; ++dim) { + if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) { + return false; + } + } + return true; + } + + /// Returns a TensorRef pointing to the first element of the tensor. + CUTLASS_HOST_DEVICE + TensorRef ref() const { + return TensorRef(this->data(), this->layout()); + } + + /// Returns a TensorRef pointing to the first element of the tensor. + CUTLASS_HOST_DEVICE + ConstTensorRef const_ref() const { + return ConstTensorRef(this->data(), this->layout()); + } + + /// Returns a TensorView to const data + CUTLASS_HOST_DEVICE + ConstTensorView const_view() const { + return ConstTensorView(const_ref(), extent_); + } + + /// Returns a Tensor_view given location and size quantities + CUTLASS_HOST_DEVICE + TensorView subview( + TensorCoord extent, ///< extent of the resulting view + TensorCoord const& location = TensorCoord() ///< resulting view's origin within the old view + ) const { + + TensorView result(this->ref(), extent.clamp(extent_ - location)); + result.add_coord_offset(location); + return result; + } + + /// Returns the number of scalar elements needed to store tensor. 
+ CUTLASS_HOST_DEVICE + size_t capacity() const { + return Base::layout().capacity(extent_); + } + + /// Returns a TensorView offset by a given amount + CUTLASS_HOST_DEVICE + TensorView operator+( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) const { + + TensorView result(*this); + result.add_pointer_offset(this->offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorView& operator+=( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) { + + this->add_pointer_offset(this->offset(b)); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorView operator-( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) const { + + TensorRef result(*this); + result.add_pointer_offset(-this->offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorView& operator-=( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) { + + this->add_pointer_offset(-this->offset(b)); + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructs a TensorRef, deducing types from arguments. 
+template < + typename Element, + typename Layout +> +CUTLASS_HOST_DEVICE TensorView make_TensorView( + Element *ptr, + Layout const &layout, + typename Layout::TensorCoord const &extent) { + + return TensorView(ptr, layout, extent); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_view_planar_complex.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_view_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..6b8f7b47c49d75f0b000d134031ea169fcc6d2a6 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/tensor_view_planar_complex.h @@ -0,0 +1,302 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines a structure containing strides and a pointer to tensor data. + + TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus, + it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from + data storage and is therefore lightweight and may be embedded in larger tensor objects or + memory structures. + + See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to + linear memory. 
+*/ + +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref_planar_complex.h" +#include "cutlass/tensor_view.h" // cutlass::TensorView + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Data type of element stored within tensor + typename Element_, + /// Maps a Coord in the logical tensor index space to the internal n-D array + typename Layout_ +> +class TensorViewPlanarComplex : public TensorRefPlanarComplex { + public: + + /// Base tensor reference + using Base = cutlass::TensorRefPlanarComplex; + + /// Mapping function from logical coordinate to internal n-D array + using Layout = Layout_; + + /// TensorRef pointing to constant memory + using ConstTensorRef = typename Base::ConstTensorRef; + + /// Underlying TensorRef type + using TensorRef = Base; + + /// Data type of individual access + using Element = Element_; + + /// Reference type to an element + using Reference = Element &; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Coordinate in storage n-D array + using Stride = typename Layout::Stride; + + /// TensorView pointing to constant memory + using ConstTensorView = TensorViewPlanarComplex< + typename platform::remove_const::type const, + Layout>; + + /// TensorView pointing to non-constant memory + using NonConstTensorView = TensorViewPlanarComplex< + typename platform::remove_const::type, + Layout>; + + /// Require at least rank=1. 
Mathematically, a rank=0 tensor would be considered to be a + /// scalar, but degenerate cases such as these are difficult to accommodate without + /// extensive C++ metaprogramming or support for zero-length arrays. + static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); + + private: + + /// View extent + TensorCoord extent_; + + public: + + // + // Methods + // + + /// Constructs a TensorView object + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex(TensorCoord const &extent = TensorCoord()): extent_(extent) { + + } + + /// Constructs a TensorView object + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex( + Element *ptr, ///< pointer to start of tensor + Layout const &layout, ///< layout object containing stride and mapping function + LongIndex imaginary_stride, ///< stride between real and imaginary part + TensorCoord const &extent ///< size of the view in logical coordinates + ): + Base(ptr, layout, imaginary_stride), extent_(extent) { + + } + + /// Constructs a TensorView object + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex( + TensorRef const &ref, ///< pointer and layout object referencing a tensor + TensorCoord const &extent ///< logical size of tensor + ): + Base(ref), extent_(extent) { + + } + + /// Converting constructor from TensorRef to non-constant data. + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex( + NonConstTensorView const &view ///< TensorView to non-const data + ): + Base(view), extent_(view.extent_) { } + + /// Updates the pointer and layout object + CUTLASS_HOST_DEVICE + void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride, TensorCoord size) { + Base::reset(ptr, layout, imaginary_stride); + this->resize(extent_); + } + + /// Changes the size of the view without affecting pointer or layout + CUTLASS_HOST_DEVICE + void resize(TensorCoord extent) { + this->extent_ = extent; + } + + /// Returns the extent of the view (the size along each logical dimension). 
+ CUTLASS_HOST_DEVICE + TensorCoord const& extent() const { return extent_; } + + /// Returns the extent along a particular logical dimension. + CUTLASS_HOST_DEVICE + Index extent(int dim) const { return extent_.at(dim); } + + /// Determines whether a location is within a tensor + CUTLASS_HOST_DEVICE + bool contains(TensorCoord const& coord) const { + CUTLASS_PRAGMA_UNROLL + for (int dim = 0; dim < kRank; ++dim) { + if (!(coord[dim] >= 0 && coord[dim] < extent(dim))) { + return false; + } + } + return true; + } + + /// Returns a TensorRef pointing to the first element of the tensor. + CUTLASS_HOST_DEVICE + Base ref() const { + return Base(this->data(), this->layout(), this->imaginary_stride()); + } + + /// Returns a TensorRef pointing to the first element of the tensor. + CUTLASS_HOST_DEVICE + ConstTensorRef const_ref() const { + return ConstTensorRef(this->data(), this->layout()); + } + + /// Returns a TensorView to const data + CUTLASS_HOST_DEVICE + ConstTensorView const_view() const { + return ConstTensorView(const_ref(), extent_); + } + + /// Returns a Tensor_view given location and size quantities + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex subview( + TensorCoord extent, ///< extent of the resulting view + TensorCoord const& location = TensorCoord() ///< resulting view's origin within the old view + ) const { + + TensorViewPlanarComplex result(this->ref(), extent.clamp(extent_ - location)); + result.add_coord_offset(location); + return result; + } + + /// Returns the number of scalar elements needed to store tensor. 
+ CUTLASS_HOST_DEVICE + size_t capacity() const { + return Base::layout().capacity(extent_); + } + + /// Returns a TensorView offset by a given amount + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex operator+( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) const { + + TensorViewPlanarComplex result(*this); + result.add_pointer_offset(this->offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex& operator+=( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) { + + this->add_pointer_offset(this->offset(b)); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex operator-( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) const { + + TensorRef result(*this); + result.add_pointer_offset(-this->offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorViewPlanarComplex& operator-=( + TensorCoord const& b ///< offset in the logical coordinate space of the tensor + ) { + + this->add_pointer_offset(-this->offset(b)); + return *this; + } + + /// TensorRef to real-valued tensor + CUTLASS_HOST_DEVICE + cutlass::TensorView view_real() const { + return cutlass::TensorView(this->data(), this->layout(), extent_); + } + + /// TensorRef to real-valued tensor + CUTLASS_HOST_DEVICE + cutlass::TensorView view_imag() const { + return cutlass::TensorView(this->imaginary_data(), this->layout(), extent_); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructs a TensorRef, deducing types from arguments. 
+template < + typename Element, + typename Layout +> +CUTLASS_HOST_DEVICE TensorViewPlanarComplex make_TensorViewPlanarComplex( + Element *ptr, + Layout const &layout, + typename Layout::LongIndex imaginary_stride, + typename Layout::TensorCoord const &extent) { + + return TensorViewPlanarComplex(ptr, layout, imaginary_stride, extent); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/tfloat32.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/tfloat32.h new file mode 100644 index 0000000000000000000000000000000000000000..7bc13e177f1d027fbba789367ac3f2ee5b748877 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/tfloat32.h @@ -0,0 +1,479 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Defines a proxy class for storing Tensor Float 32 data type. +*/ +#pragma once + +#if defined(__CUDACC_RTC__) +#include "cutlass/floating_point_nvrtc.h" +#else +#include +#include +#include +#include // std::memcpy +#endif + +#include "cutlass/cutlass.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tensor Float 32 data type +struct alignas(4) tfloat32_t { + + // + // Data members + // + + /// Storage type + uint32_t storage; + + // + // Methods + // + private: + CUTLASS_HOST_DEVICE + static uint32_t float_to_storage(float s) { + #if defined(__CUDA_ARCH__) + uint32_t result = reinterpret_cast(s); + #else + uint32_t result; + std::memcpy(&result, &s, sizeof(float)); + #endif + return result; + } + + public: + /// Constructs from an unsigned int + CUTLASS_HOST_DEVICE + static tfloat32_t bitcast(uint32_t x) { + tfloat32_t h; + h.storage = x; + return h; + } + + /// Emulated rounding is fast in device code + CUTLASS_HOST_DEVICE + static tfloat32_t 
round_half_ulp_truncate(float const &s) { + uint32_t x = float_to_storage(s); + + #if defined(__CUDA_ARCH__) + if (::isfinite(s)) { + x += 0x1000u; + } + #else + if (std::isfinite(s)) { + x += 0x1000u; + } + #endif + + return tfloat32_t::bitcast(x); + } + + tfloat32_t() = default; + + /// Floating-point conversion - round toward nearest even + CUTLASS_HOST_DEVICE + explicit tfloat32_t(float x): storage(round_half_ulp_truncate(x).raw()) { } + + // Conversion from double (this rounds twice) + CUTLASS_HOST_DEVICE + explicit tfloat32_t(double x): tfloat32_t(float(x)) { } + + /// Integer conversion - round toward zero + CUTLASS_HOST_DEVICE + explicit tfloat32_t(int x) { + float flt = static_cast(x); + #if defined(__CUDA_ARCH__) + storage = reinterpret_cast(flt); + #else + std::memcpy(&storage, &flt, sizeof(storage)); + #endif + } + + // Conversion to float + CUTLASS_HOST_DEVICE + operator float() const { + + // Conversions to IEEE single-precision requires clearing dont-care bits + // of the mantissa. 
+ unsigned bits = (storage & ~0x1fffu); + + #if defined(__CUDA_ARCH__) + return reinterpret_cast(bits); + #else + float flt; + std::memcpy(&flt, &bits, sizeof(flt)); + return flt; + #endif + } + + /// Converts to double + CUTLASS_HOST_DEVICE + explicit operator double() const { + return double(float(*this)); + } + + /// Converts to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(float(*this)); + } + + /// Casts to bool + CUTLASS_HOST_DEVICE + explicit operator bool() const { + return (float(*this) != 0.0f); + } + + /// Obtains raw bits + CUTLASS_HOST_DEVICE + uint32_t raw() const { + return storage; + } + + /// Returns the sign bit + CUTLASS_HOST_DEVICE + bool signbit() const { + return ((raw() & 0x80000000) != 0); + } + + /// Returns the biased exponent + CUTLASS_HOST_DEVICE + int exponent_biased() const { + return int((raw() >> 23) & 0x0ff); + } + + /// Returns the unbiased exponent + CUTLASS_HOST_DEVICE + int exponent() const { + return exponent_biased() - 127; + } + + /// Returns the mantissa + CUTLASS_HOST_DEVICE + int mantissa() const { + return int(raw() & 0x7fffff); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTLASS_HOST_DEVICE +bool signbit(cutlass::tfloat32_t const& h) { + return h.signbit(); +} + +CUTLASS_HOST_DEVICE +cutlass::tfloat32_t abs(cutlass::tfloat32_t const& h) { + return cutlass::tfloat32_t::bitcast(h.raw() & 0x7fffffff); +} + +CUTLASS_HOST_DEVICE +bool isnan(cutlass::tfloat32_t const& h) { + return (h.exponent_biased() == 0x0ff) && h.mantissa(); +} + +CUTLASS_HOST_DEVICE +bool isfinite(cutlass::tfloat32_t const& h) { + return (h.exponent_biased() != 0x0ff); +} + +CUTLASS_HOST_DEVICE +cutlass::tfloat32_t nan_tf32(const char*) { + // NVIDIA canonical NaN + return cutlass::tfloat32_t::bitcast(0x7fffffff); +} + +CUTLASS_HOST_DEVICE +bool isinf(cutlass::tfloat32_t const& h) { + return (h.exponent_biased() == 0x0ff) && !h.mantissa(); +} + 
+CUTLASS_HOST_DEVICE +bool isnormal(cutlass::tfloat32_t const& h) { /* normal: biased exponent is neither 0 nor 0x0ff */ + return h.exponent_biased() && h.exponent_biased() != 0x0ff; +} + +CUTLASS_HOST_DEVICE +int fpclassify(cutlass::tfloat32_t const& h) { /* classifies by biased exponent first, then by whether the mantissa is non-zero */ + int exp = h.exponent_biased(); + int mantissa = h.mantissa(); + if (exp == 0x0ff) { + if (mantissa) { + return FP_NAN; + } + else { + return FP_INFINITE; + } + } + else if (!exp) { + if (mantissa) { + return FP_SUBNORMAL; + } + else { + return FP_ZERO; + } + } + return FP_NORMAL; +} + +CUTLASS_HOST_DEVICE +cutlass::tfloat32_t sqrt(cutlass::tfloat32_t const& h) { /* square root is taken on the float conversion and converted back */ +#if defined(__CUDACC_RTC__) + return cutlass::tfloat32_t(sqrtf(float(h))); +#else + return cutlass::tfloat32_t(std::sqrt(float(h))); +#endif +} + +CUTLASS_HOST_DEVICE +tfloat32_t copysign(tfloat32_t const& a, tfloat32_t const& b) { /* magnitude bits of a combined with the sign bit of b */ + + uint32_t a_mag = (a.raw() & 0x7fffffff); + uint32_t b_sign = (b.raw() & 0x80000000); + uint32_t result = (a_mag | b_sign); + + return tfloat32_t::bitcast(result); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Standard Library operations and definitions +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace std { + +#if !defined(__CUDACC_RTC__) +/// Numeric limits specialization whose members return cutlass::tfloat32_t values +template <> +struct numeric_limits { + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_infinity = true; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; + static std::float_denorm_style const has_denorm = std::denorm_present; + static bool const has_denorm_loss = true; + static std::float_round_style const round_style = std::round_to_nearest; + static bool const is_iec559 = false; + 
static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = 19; /* NOTE(review): the float conversion masks storage with ~0x1fffu, leaving 10 stored mantissa bits - confirm 19 is intended */ + + /// Least positive value (bit pattern 0x01, the same pattern denorm_min() returns). NOTE(review): std::numeric_limits min() is conventionally the smallest positive normal value - confirm this is intended + static cutlass::tfloat32_t min() { return cutlass::tfloat32_t::bitcast(0x01); } + + /// Lowest (most negative) finite value + static cutlass::tfloat32_t lowest() { return cutlass::tfloat32_t::bitcast(0xff7fffff); } + + /// Maximum finite value + static cutlass::tfloat32_t max() { return cutlass::tfloat32_t::bitcast(0x7f7fffff); } + + /// Returns the value reported as epsilon (bit pattern 0x1000). NOTE(review): that pattern has a zero exponent field - confirm it is the intended epsilon + static cutlass::tfloat32_t epsilon() { return cutlass::tfloat32_t::bitcast(0x1000); } + + /// Returns the maximum rounding error, 0.5 + static cutlass::tfloat32_t round_error() { return cutlass::tfloat32_t(0.5f); } + + /// Returns positive infinity (exponent field all ones, zero mantissa) + static cutlass::tfloat32_t infinity() { return cutlass::tfloat32_t::bitcast(0x7f800000); } + + /// Returns a NaN (bit pattern 0x7fffffff) + static cutlass::tfloat32_t quiet_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); } + + /// Returns the same bit pattern as quiet_NaN(); has_signaling_NaN is declared false above + static cutlass::tfloat32_t signaling_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); } + + /// Returns the smallest positive subnormal value (bit pattern 0x1) + static cutlass::tfloat32_t denorm_min() { return cutlass::tfloat32_t::bitcast(0x1); } +}; +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace std + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Arithmetic operators +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTLASS_HOST_DEVICE +bool operator==(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return float(lhs) == float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator!=(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return float(lhs) != float(rhs); +} + 
+CUTLASS_HOST_DEVICE +bool operator<(tfloat32_t const& lhs, tfloat32_t const& rhs) { /* relational operators compare the float conversions of both operands */ + return float(lhs) < float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator<=(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return float(lhs) <= float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator>(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return float(lhs) > float(rhs); +} + +CUTLASS_HOST_DEVICE +bool operator>=(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return float(lhs) >= float(rhs); +} + +CUTLASS_HOST_DEVICE +tfloat32_t operator+(tfloat32_t const& lhs, tfloat32_t const& rhs) { /* binary arithmetic is performed in float and the result converted back to tfloat32_t */ + return tfloat32_t(float(lhs) + float(rhs)); +} + + +CUTLASS_HOST_DEVICE +tfloat32_t operator-(tfloat32_t const& lhs) { /* negation flips only the sign bit of the raw storage */ + return tfloat32_t::bitcast(0x80000000 ^ lhs.raw()); +} + +CUTLASS_HOST_DEVICE +tfloat32_t operator-(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return tfloat32_t(float(lhs) - float(rhs)); +} + +CUTLASS_HOST_DEVICE +tfloat32_t operator*(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return tfloat32_t(float(lhs) * float(rhs)); +} + +CUTLASS_HOST_DEVICE +tfloat32_t operator/(tfloat32_t const& lhs, tfloat32_t const& rhs) { + return tfloat32_t(float(lhs) / float(rhs)); +} + +CUTLASS_HOST_DEVICE +tfloat32_t& operator+=(tfloat32_t & lhs, tfloat32_t const& rhs) { /* compound forms compute in float, store into lhs, and return lhs */ + lhs = tfloat32_t(float(lhs) + float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +tfloat32_t& operator-=(tfloat32_t & lhs, tfloat32_t const& rhs) { + lhs = tfloat32_t(float(lhs) - float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +tfloat32_t& operator*=(tfloat32_t & lhs, tfloat32_t const& rhs) { + lhs = tfloat32_t(float(lhs) * float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +tfloat32_t& operator/=(tfloat32_t & lhs, tfloat32_t const& rhs) { + lhs = tfloat32_t(float(lhs) / float(rhs)); + return lhs; +} + +CUTLASS_HOST_DEVICE +tfloat32_t& operator++(tfloat32_t & lhs) { /* pre-increment: steps a float copy, stores it back, returns the updated lhs */ + float tmp(lhs); + ++tmp; + lhs = tfloat32_t(tmp); + return lhs; +} + +CUTLASS_HOST_DEVICE +tfloat32_t& operator--(tfloat32_t & 
lhs) { /* pre-decrement: steps a float copy, stores it back, returns the updated lhs */ + float tmp(lhs); + --tmp; + lhs = tfloat32_t(tmp); + return lhs; +} + +CUTLASS_HOST_DEVICE +tfloat32_t operator++(tfloat32_t & lhs, int) { /* post-increment: returns the value held before the update */ + tfloat32_t ret(lhs); + float tmp(lhs); + tmp++; + lhs = tfloat32_t(tmp); + return ret; +} + +CUTLASS_HOST_DEVICE +tfloat32_t operator--(tfloat32_t & lhs, int) { /* post-decrement: returns the value held before the update */ + tfloat32_t ret(lhs); + float tmp(lhs); + tmp--; + lhs = tfloat32_t(tmp); + return ret; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// User-defined literals: the floating-point form converts through float(x), the integer form through int(x) +// + +CUTLASS_HOST_DEVICE +cutlass::tfloat32_t operator "" _tf32(long double x) { + return cutlass::tfloat32_t(float(x)); +} + +CUTLASS_HOST_DEVICE +cutlass::tfloat32_t operator "" _tf32(unsigned long long int x) { + return cutlass::tfloat32_t(int(x)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/thread/matrix.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/thread/matrix.h new file mode 100644 index 0000000000000000000000000000000000000000..c338306132b9d9b2e42ff26759f7d1b3a7bc1ae3 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/thread/matrix.h @@ -0,0 +1,198 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines a matrix object intended for storing data in registers and operations within + a CUDA thread. 
+*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/matrix_coord.h" + +namespace cutlass { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Per-thread matrix object storing a packed kRows x kColumns matrix on top of an Array base +template < + typename Element, + int Rows, + int Columns, + typename Layout = layout::RowMajor +> +class Matrix : public Array { +public: + + // Verify layout refers to a rank=2 matrix. + static_assert( + Layout::kRank == 2, + "Layout type must refer to a rank=2 matrix"); + + /// Base type + using Base = Array; + + /// Element type. NOTE(review): Element_ does not appear in the template parameter list above (Element, Rows, Columns, Layout) - confirm + using Element = Element_; + + /// Number of rows + static int const kRows = Rows; + + /// Number of columns + static int const kColumns = Columns; + + /// Layout within the array. NOTE(review): Layout_ does not appear in the template parameter list above - confirm + using Layout = Layout_; + + /// Reference type to an element + using Reference = Element &; + + /// Logical rank of tensor index space + static int const kRank = 2; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Stride type + using Stride = typename Layout::Stride; + + /// TensorRef to matrix object + using TensorRef = TensorRef; + + /// TensorRef to constant matrix object + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + /// TensorView to matrix object + using TensorView = TensorView; + + /// TensorView to constant matrix object + using ConstTensorView = typename TensorView::ConstTensorView; + + /// Diagonal vector + using Diagonal = Vector; + +private: + + +public: + + // + // Methods + // + + /// Returns the logical extent of the matrix as (kRows, kColumns) + CUTLASS_HOST_DEVICE + static MatrixCoord extent() { + return make_Coord(kRows, kColumns); + } + + /// Returns the packed layout for the matrix extent + CUTLASS_HOST_DEVICE + static Layout layout() { + return 
Layout::packed(extent()); + } + + /// Default ctor; the body is empty + CUTLASS_HOST_DEVICE + Matrix() { } + + /// Ctor taking a diagonal vector. NOTE(review): the body is empty, so diag is currently ignored - confirm + CUTLASS_HOST_DEVICE + Matrix(Diagonal const &diag) { + } + + /// Returns a TensorRef pointing to the first element of the tensor. + CUTLASS_HOST_DEVICE + TensorRef ref() { + return TensorRef(this->data(), layout()); + } + + /// Returns a ConstTensorRef pointing to the first element of the tensor. + CUTLASS_HOST_DEVICE + ConstTensorRef const_ref() const { + return ConstTensorRef(this->data(), layout()); + } + + /// Returns a TensorView built from ref() and the full extent of the tensor. + CUTLASS_HOST_DEVICE + TensorView view() { + return TensorView(ref(), extent()); + } + + /// Returns a TensorView to const data + CUTLASS_HOST_DEVICE + ConstTensorView const_view() const { + return ConstTensorView(const_ref(), extent()); + } + + /// Returns a reference to the element at a given Coord. NOTE(review): the method is const yet returns the mutable Reference type - confirm against the Array::at overloads + CUTLASS_HOST_DEVICE + Reference at(MatrixCoord const& coord) const { + typename Base::size_type offset_(layout().offset(coord)); + return Base::at(offset_); + } + + /// Returns the number of scalar elements held by the underlying array (Base::size()). 
+ CUTLASS_HOST_DEVICE + LongIndex capacity() const { + return LongIndex(Base::size()); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Column vector defined as a matrix with exactly one column +template < + typename Element, + int Rows, + typename Layout = layout::ColumnMajor +> +using ColumnVector = Matrix; + +/// Row vector defined as a matrix with exactly one row +template < + typename Element, + int Columns, + typename Layout = layout::RowMajor +> +using RowVector = Matrix; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/trace.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/trace.h new file mode 100644 index 0000000000000000000000000000000000000000..803c72eca35a4cc3ee0712981942016f987f5b44 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/trace.h @@ -0,0 +1,59 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Helpers for optionally tracing through code when debugging. + + This file is to be included after all other headers. 
+*/ + +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Tracing options +#ifndef CUTLASS_DEBUG_TRACE_LEVEL +#define CUTLASS_DEBUG_TRACE_LEVEL 0 +#endif + +#if CUTLASS_DEBUG_TRACE_LEVEL +#include +#include "cutlass/core_io.h" +#if defined(__CUDA_ARCH__) +#define CUTLASS_TRACE_HOST(x) +#else +#define CUTLASS_TRACE_HOST(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; } +#endif +#else +#define CUTLASS_TRACE_HOST(x) +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp new file mode 100644 index 0000000000000000000000000000000000000000..41bc4786c7a8d148340a23bf1ce1db66f04f10b4 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp @@ -0,0 +1,754 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing how threads are mapped to a given tile. 
+*/ + +#pragma once + +#include "cute/arch/mma_sm90_gmma.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { +using namespace cute; + +template +constexpr auto +gmma_smem_transpose_or_passthrough() { + if constexpr (Transpose) { + if constexpr (cute::is_same_v, SmemLayoutAtom>) { + return GMMA::Layout_K_SW128_Atom{}; + } + else if constexpr (cute::is_same_v, SmemLayoutAtom>) { + return GMMA::Layout_K_SW64_Atom{}; + } + else if constexpr (cute::is_same_v, SmemLayoutAtom>) { + return GMMA::Layout_K_SW32_Atom{}; + } + else if constexpr (cute::is_same_v, SmemLayoutAtom>) { + return GMMA::Layout_K_INTER_Atom{}; + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported Layout_SW_Atom for B SMEM transposition"); + } + } + else { + return SmemLayoutAtom{}; + } +} + +template +constexpr auto +use_universal_transposition() { + if constexpr (sizeof(ElementType) == 1) { + return !cute::is_same_v, SmemCopyAtom>; + } + else if constexpr (sizeof(ElementType) == 4){ + // Only universal transposition can handle SW64 and Non swizzle SMEM layout + if constexpr (cute::is_same_v, SmemCopyAtom> || + cute::is_same_v, SmemCopyAtom>) { + return true; + } + else { + return false; + } + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported ElementType for B SMEM transposition"); + } +} + +template< + class TiledMma_, + class SmemLayoutB_, + class SmemLayoutAtomB_, + class ElementB_> +class NoTranspositionOperandB { +public: + using TiledMma = TiledMma_; + using SmemLayoutB = SmemLayoutB_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using ElementB = ElementB_; + + constexpr CUTLASS_HOST_DEVICE + NoTranspositionOperandB( + int, + int, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, + ElementB) { } + + template < + 
class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void operator()( + TensorSmemB const&, + TensorTransposedSmemB const&, + int, int) { } + + CUTLASS_DEVICE void synchronize(int) { } + + CUTLASS_DEVICE void synchronize() { } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void transpose( + TensorSmemB const&, + TensorTransposedSmemB const&, + int) { } +}; + +template< + class TiledMma_, + class SmemLayoutB_, + class SmemLayoutAtomB_, + class ElementB_> +class UniversalTranspositionOperandB { +public: + using TiledMma = TiledMma_; + using SmemLayoutB = SmemLayoutB_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using ElementB = ElementB_; + + constexpr CUTLASS_HOST_DEVICE + UniversalTranspositionOperandB( + int warp_idx_, + int warp_group_thread_idx_, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, + ElementB) + : warp_idx(warp_idx_) + , warp_group_thread_idx(warp_group_thread_idx_) { } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void operator()( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + int read_stage, int current_step) { + if (current_step > 0) { + return; + } + + constexpr int NumMathWarpGroup = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; + static_assert(NumMathWarpGroup == 1 || + (!detail::use_universal_transposition() && NumMathWarpGroup == 2), + "Wrong math warp group number for TransposeB"); + constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K. + + constexpr int BytesPerSmemSwizzleUnit = 16; + constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB); + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// Universal transposition, need warp_group sync between load and store. + /// The number of reg used depends on the input elementB. 
+ ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /* + In one copy step, a warp group would load WarpgroupTileSize * WarpgroupTileSize tile then store to transposed location. + In warp_group_tile, each warp holds Four WarpTileSize x WarpTileSize elements: + K + ------------ + | W0 W1 W2 W3 --- + | W0 W1 W2 W3 | + | W0 W1 W2 W3 | --> Copy Step 0 + | W0 W1 W2 W3 --- + .... + | W0 W1 W2 W3 --- + | W0 W1 W2 W3 | + | W0 W1 W2 W3 | --> Copy Step n + | W0 W1 W2 W3 --- + */ + static_assert((NumThreadsPerWarpGroup % WarpThreadShapeN == 0), "Unsupported warp thread layout."); + constexpr auto WarpgroupThreadLayout = make_layout(make_shape(Int{}, Int{})); + + // Get copy tile and partition to each thread + auto sB_tiled_copy = make_tiled_copy( + Copy_Atom{}, + WarpgroupThreadLayout, // thr_layout + Layout<_1>{} // val_layout + ); + static_assert(size(sB_tiled_copy) == size(TiledMma{}), "Wrong thread number in TiledCopy."); + + auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx); + Tensor tCsB = sB_thr_copy.partition_S( sB(_,_,read_stage)); // (CPY, CPY_N, CPY_K) + Tensor tCsB_transposed = sB_thr_copy.partition_D(gmma_sB(_,_,read_stage)); // (CPY, CPY_N, CPY_K) + + // Divide partitioned tile to limit register usage + constexpr int CopySteps = size<0>(SmemLayoutB{}) / WarpgroupTileSize; + constexpr auto CopyTileShape = make_shape(size<0>(tCsB), Int< size<1>(tCsB) / CopySteps >{}, size<2>(tCsB)); + static_assert(size<1>(tCsB) % CopySteps == 0, "CopySteps must evenly divide rank 1 size of partitioned SMEM."); + + Tensor tCsB_copy_tile = zipped_divide(tCsB, CopyTileShape); + Tensor tCsB_copy_tile_transposed = zipped_divide(tCsB_transposed, CopyTileShape); + auto transpose_fragment = make_fragment_like(tCsB_copy_tile(_,_0{})); + + CUTLASS_PRAGMA_NO_UNROLL + for (int step = 0; step < CopySteps; ++step) { + copy(sB_tiled_copy, tCsB_copy_tile(_,step), transpose_fragment); + + // Make sure all 
elements are read before being overwritten + __syncthreads(); + + copy(sB_tiled_copy, transpose_fragment, tCsB_copy_tile_transposed(_,step)); + } + } + + CUTLASS_DEVICE void synchronize(int step) { + if (step == 0) { + // SMEM fence to make sure B is transposed before math + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); + } + } + + CUTLASS_DEVICE void synchronize() { + // SMEM fence to make sure B is transposed before math + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); + } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void transpose( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + int read_stage) { + + this->operator()(sB, gmma_sB, read_stage, 0); + synchronize(); + + } + +private: + const int warp_idx; + const int warp_group_thread_idx; +}; + +template< + class TiledMma_, + class SmemLayoutB_, + class SmemLayoutAtomB_, + class ElementB_> +class AsyncTranspositionOperandB { +public: + + using TiledMma = TiledMma_; + using SmemLayoutB = SmemLayoutB_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using ElementB = ElementB_; + + static constexpr int Steps = 2; + static constexpr int NumMathWarpGroup = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; + static constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup; + static_assert(NumMathWarpGroup <= 2, + "Wrong math warp group number for TransposeB"); + static constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K. 
+ static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; + + static constexpr int BytesPerSmemSwizzleUnit = 16; + static constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB); + static constexpr int WarpThreadShapeK = NumThreadsPerWarp / WarpThreadShapeN; + static constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1); + + static constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile; + static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invalid warp thread shape." ); + static constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step. + static constexpr int64_t WarpTileNCoordLUT = 06723763275316420; + static constexpr int64_t WarpTileKCoordLUT = 05410541064206420; + static constexpr int NumStepsEncoded = 4; // Only encoding first 4 steps into LUT. + static constexpr int MaskPerStep = 07; // Each step is encoded into 3bits, + static constexpr int NumBitsPerStep = 3; + static constexpr int MaskPerWarp = 07777; // Each warp has 4 steps(12 bits) + static constexpr int NumBitsPerWarp = 12; + // Number of warp_group_tiles + static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0, + "Copy size must evenly divide SMEM tile."); + static constexpr int WarpgroupTileNum = size<0>(SmemLayoutB{}) / WarpgroupTileSize; + + static_assert(size<2>(typename TiledMma::AtomShape_MNK{}) <= WarpThreadShapeK, + "Need to be able to transpose first k-block in the first step"); + + constexpr CUTLASS_HOST_DEVICE + AsyncTranspositionOperandB( + int warp_idx_, + int warp_group_thread_idx_, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, + ElementB) + : warp_idx(warp_idx_) + , warp_group_thread_idx(warp_group_thread_idx_) + , warp_idx_in_warp_group(warp_idx_ % NumWarpsPerWarpGroup) + , current_warp_tile_n_coord_LUT((WarpTileNCoordLUT >> ((warp_idx_ + % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) + , 
current_warp_tile_k_coord_LUT((WarpTileKCoordLUT >> ((warp_idx_ + % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) { } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void operator()( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + int read_stage, int current_step) + { + if (current_step >= StepsPerWarpGroup) { + return; + } + + static constexpr auto WarpThreadLayout = make_layout(make_shape(Int{}, Int{})); + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// A warp group uses 2 steps to transpose the whole WarpgroupTileSize x WarpgroupTileSize. + /// In each step, one warp would hold two warp_tiles. + /// Step 0: Step 1: + /// W0 W1 W2 W3 -- -- -- -- + /// W1 W0 -- -- -- -- W3 W2 + /// W2 -- -- -- -- W3 W0 W1 + /// W3 -- -- -- -- W2 W1 W0 + /// + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// + /// Fully static coord LUT to avoid extra register use. + /// [warp_id][step][warp_tile][n / k] + /// Step 0 Step 1 Step 2 Step 3 Step 4 Step 5 Step 6 Step 7 + /// {{{0,0}, {1,1}}, {{2,2}, {3,3}}, {{4,4}, {5,5}}, {{6,6}, {7,7}}, {{4,0}, {0,4}}, {{4,1}, {1,4}}, {{4,2}, {2,4}}, {{4,3}, {3,4}}}, // W0 + /// {{{1,0}, {0,1}}, {{3,2}, {2,3}}, {{5,4}, {4,5}}, {{7,6}, {6,7}}, {{5,0}, {0,5}}, {{5,1}, {1,5}}, {{5,2}, {2,5}}, {{5,3}, {3,5}}}, // W1 + /// {{{2,0}, {0,2}}, {{3,1}, {1,3}}, {{6,4}, {4,6}}, {{7,5}, {5,7}}, {{6,0}, {0,6}}, {{6,1}, {1,6}}, {{6,2}, {2,6}}, {{6,3}, {3,6}}}, // W2 + /// {{{3,0}, {0,3}}, {{2,1}, {1,2}}, {{7,4}, {4,7}}, {{6,5}, {5,6}}, {{7,0}, {0,7}}, {{7,1}, {1,7}}, {{7,2}, {2,7}}, {{7,3}, {3,7}}}, // W3 + /// + /// Encoding the coord of warp tile0 into two int64_t values. + /// Only encoding Step 0 ~ Step 3, since Step 4 ~ Step 7 have a straightforward pattern. + /// Only encoding warp tile0, since the coords of warp tile1 could be easily deduced from warp tile0. 
+ /// The 2-step transposition and the 8-step transposition share the same encoding. + /// + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + // Divide entire SMEM into multiple warp_tiles + constexpr auto WarpTileShape = make_shape(Int(), Int()); + Tensor s_tile = zipped_divide( sB(_,_,read_stage), WarpTileShape); + Tensor s_tile_transposed = zipped_divide(gmma_sB(_,_,read_stage), WarpTileShape); + + // Get copy tile + auto sB_tiled_copy = make_tiled_copy( + Copy_Atom{}, + WarpThreadLayout, // thr_layout + Layout<_1>{} // val_layout + ); + + static_assert(size(sB_tiled_copy) * NumWarpsPerWarpGroup == size(TiledMma{}) / NumMathWarpGroup, "Wrong thread number in TiledCopy."); + auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx % NumThreadsPerWarp); // slice based on lane_idx + + // Construct fragments for transposition + Tensor tmp_tCsB = sB_thr_copy.partition_S(flatten(s_tile(_, make_coord(_0{}, _0{})))); + decltype(make_fragment_like(tmp_tCsB)) transpose_fragments[TilesPerWarp] = { + make_fragment_like(tmp_tCsB), + make_fragment_like(tmp_tCsB) + }; + + [[maybe_unused]] int step = current_step * NumMathWarpGroup; + if constexpr (NumMathWarpGroup == 2) { + // For 2 math warpgroups, warp idx 4~7 is the 1st warp group and 8~11 is the 2nd, so dividing warp idx by 8 tells whether this warp belongs to the 2nd warp group. + step += warp_idx / (NumWarpsPerWarpGroup * 2); + } + + int tmp_warp_tile_n_coord_LUT = current_warp_tile_n_coord_LUT >> (NumBitsPerStep * current_step); + int tmp_warp_tile_k_coord_LUT = current_warp_tile_k_coord_LUT >> (NumBitsPerStep * current_step); + + if constexpr (NumMathWarpGroup == 2) { + tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2)); + tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2)); + } + + // decoding the warp tile coord. 
+ int warp_tile0_n, warp_tile0_k; + if constexpr (StepsPerWarpGroup <= NumStepsEncoded) { + warp_tile0_n = tmp_warp_tile_n_coord_LUT & MaskPerStep; + warp_tile0_k = tmp_warp_tile_k_coord_LUT & MaskPerStep; + } else { + warp_tile0_n = step < NumStepsEncoded ? (tmp_warp_tile_n_coord_LUT & MaskPerStep) : 4 + warp_idx_in_warp_group; + warp_tile0_k = step < NumStepsEncoded ? (tmp_warp_tile_k_coord_LUT & MaskPerStep) : step - 4; + } + + int warp_tile1_n = warp_tile0_n == warp_tile0_k ? warp_tile0_n + 1 : warp_tile0_k; + int warp_tile1_k = warp_tile0_n == warp_tile0_k ? warp_tile0_k + 1 : warp_tile0_n; + + CUTLASS_PRAGMA_UNROLL + for (int warp_group_tile = 0; warp_group_tile < WarpgroupTileNum; ++warp_group_tile) { + + static_assert(TilesPerWarp == 2); + + // [warp_tile][n/k] + const int warp_tile_coord[TilesPerWarp][2] = { + // n k + {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile0_n, warp_tile0_k}, // warp_tile 0 + {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile1_n, warp_tile1_k} // warp_tile 1 + }; + + CUTLASS_PRAGMA_UNROLL + for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) { + Tensor tCsB = sB_thr_copy.partition_S( + flatten(s_tile(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) + ); // (CPY, CPY_N, CPY_K) + + copy(sB_tiled_copy, tCsB, transpose_fragments[warp_tile]); + } + + // Make sure elements in two 8x8 warp tiles are all consumed + __syncwarp(); + + CUTLASS_PRAGMA_UNROLL + for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) { + Tensor tCsB_transposed = sB_thr_copy.partition_D( + flatten(s_tile_transposed(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) + ); // (CPY, CPY_N, CPY_K) + copy(sB_tiled_copy, transpose_fragments[warp_tile], tCsB_transposed); + } + + } // loop warp_group_tile + } + + CUTLASS_DEVICE void synchronize(int step) { + if (step < StepsPerWarpGroup) { + // SMEM fence to make sure B is transposed before math + 
cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); + } + } + + CUTLASS_DEVICE void synchronize() { + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); + } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void transpose( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + int read_stage) { + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < StepsPerWarpGroup; ++i) { + this->operator()(sB, gmma_sB, read_stage, i); + } + synchronize(); + + } +private: + const int warp_idx; + const int warp_group_thread_idx; + const int warp_idx_in_warp_group; + const int current_warp_tile_n_coord_LUT; + const int current_warp_tile_k_coord_LUT; +}; + +template< + class TiledMma_, + class SmemLayoutB_, + class SmemLayoutAtomB_, + class ElementB_> +class AsyncTranspositionOperandB_1BElementB { +public: + + static_assert(sizeof(ElementB_) == 1); + + using TiledMma = TiledMma_; + using SmemLayoutB = SmemLayoutB_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using ElementB = ElementB_; + + static constexpr int Steps = 8; + static constexpr int NumMathWarpGroup = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; + static constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup; + static_assert(NumMathWarpGroup <= 2, + "Wrong math warp group number for TransposeB"); + static constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K. 
+ static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; + + static constexpr int BytesPerSmemSwizzleUnit = 16; + static constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB); + static constexpr int WarpThreadShapeK = NumThreadsPerWarp / WarpThreadShapeN; + static constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1); + + static constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile; + static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invalid warp thread shape." ); + static constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step. + static constexpr int64_t WarpTileNCoordLUT = 06723763275316420; + static constexpr int64_t WarpTileKCoordLUT = 05410541064206420; + static constexpr int NumStepsEncoded = 4; // Only encoding first 4 steps into LUT. + static constexpr int MaskPerStep = 07; // Each step is encoded into 3bits, + static constexpr int NumBitsPerStep = 3; + static constexpr int MaskPerWarp = 07777; // Each warp has 4 steps(12 bits) + static constexpr int NumBitsPerWarp = 12; + // Number of warp_group_tiles + static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0, + "Copy size must evenly divide SMEM tile."); + static constexpr int WarpgroupTileNum = size<0>(SmemLayoutB{}) / WarpgroupTileSize; + + constexpr CUTLASS_HOST_DEVICE + AsyncTranspositionOperandB_1BElementB( + int warp_idx_, + int warp_group_thread_idx_, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, + ElementB) + : warp_idx(warp_idx_) + , warp_group_thread_idx(warp_group_thread_idx_) + , warp_idx_in_warp_group(warp_idx_ % NumWarpsPerWarpGroup) + , current_warp_tile_n_coord_LUT((WarpTileNCoordLUT >> ((warp_idx_ + % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) + , current_warp_tile_k_coord_LUT((WarpTileKCoordLUT >> ((warp_idx_ + % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) { } + + template < + class 
TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void operator()( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + int read_stage, int current_step) + { + if (current_step > 0) { + return; + } + + constexpr auto WarpThreadLayout = make_layout(make_shape(Int{}, Int{})); + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// A warp group uses 8 steps to transpose the whole WarpgroupTileSize x WarpgroupTileSize. + /// Divide a warp_group_tile into 8x8 warp_tiles to further reduce the reg usage. + /// Step 0: Step 1: Step 2: Step 3: + /// W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- + /// W1 W0 -- -- -- -- -- -- -- -- W3 W2 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- + /// W2 -- -- -- -- -- -- -- -- W3 W0 W1 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- + /// W3 -- -- -- -- -- -- -- -- W2 W1 W0 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- + /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- + /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W1 W0 -- -- -- -- -- -- -- -- W3 W2 + /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W2 -- -- -- -- -- -- -- -- W3 W0 W1 + /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W3 -- -- -- -- -- -- -- -- W2 W1 W0 + /// + /// Step 4: Step 5: Step 6: Step 7: + /// -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- + /// -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- + /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 -- -- -- -- -- -- -- -- + /// -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- W0 W1 W2 W3 + /// W0 -- -- -- -- -- -- -- -- W0 -- -- -- -- -- -- -- -- W0 -- -- -- -- -- -- -- -- W0 -- -- -- -- 
+ /// W1 -- -- -- -- -- -- -- -- W1 -- -- -- -- -- -- -- -- W1 -- -- -- -- -- -- -- -- W1 -- -- -- -- + /// W2 -- -- -- -- -- -- -- -- W2 -- -- -- -- -- -- -- -- W2 -- -- -- -- -- -- -- -- W2 -- -- -- -- + /// W3 -- -- -- -- -- -- -- -- W3 -- -- -- -- -- -- -- -- W3 -- -- -- -- -- -- -- -- W3 -- -- -- -- + /// + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// + /// Fully static coord LUT to avoid extra register use. + /// [warp_id][step][warp_tile][n / k] + /// Step 0 Step 1 Step 2 Step 3 Step 4 Step 5 Step 6 Step 7 + /// {{{0,0}, {1,1}}, {{2,2}, {3,3}}, {{4,4}, {5,5}}, {{6,6}, {7,7}}, {{4,0}, {0,4}}, {{4,1}, {1,4}}, {{4,2}, {2,4}}, {{4,3}, {3,4}}}, // W0 + /// {{{1,0}, {0,1}}, {{3,2}, {2,3}}, {{5,4}, {4,5}}, {{7,6}, {6,7}}, {{5,0}, {0,5}}, {{5,1}, {1,5}}, {{5,2}, {2,5}}, {{5,3}, {3,5}}}, // W1 + /// {{{2,0}, {0,2}}, {{3,1}, {1,3}}, {{6,4}, {4,6}}, {{7,5}, {5,7}}, {{6,0}, {0,6}}, {{6,1}, {1,6}}, {{6,2}, {2,6}}, {{6,3}, {3,6}}}, // W2 + /// {{{3,0}, {0,3}}, {{2,1}, {1,2}}, {{7,4}, {4,7}}, {{6,5}, {5,6}}, {{7,0}, {0,7}}, {{7,1}, {1,7}}, {{7,2}, {2,7}}, {{7,3}, {3,7}}}, // W3 + /// + /// Encoding the coord of warp tile0 into two int64_t values. + /// Only encoding Step 0 ~ Step 4, since Step 5 ~ Step 7 have a straightforward pattern. + /// Only encoding warp tile0, since the coords of warp tile1 could be easily deduced from warp tile0. + /// The 2-step transposition and the 8-step transposition share the same encoding. 
+ /// + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + // Divide entire SMEM to multiple warp_tiles + constexpr auto WarpTileShape = make_shape(Int(), Int()); + Tensor s_tile = zipped_divide( sB(_,_,read_stage), WarpTileShape); + Tensor s_tile_transposed = zipped_divide(gmma_sB(_,_,read_stage), WarpTileShape); + + // Get copy tile + auto sB_tiled_copy = make_tiled_copy( + Copy_Atom{}, + WarpThreadLayout, // thr_layout + Layout<_1>{} // val_layout + ); + static_assert(size(sB_tiled_copy) * NumWarpsPerWarpGroup == size(TiledMma{}) / NumMathWarpGroup, "Wrong thread number in TiledCopy."); + auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx % NumThreadsPerWarp); // slice based on lane_idx + + // Construct fragments for transposition + Tensor tmp_tCsB = sB_thr_copy.partition_S(flatten(s_tile(_, make_coord(_0{}, _0{})))); + decltype(make_fragment_like(tmp_tCsB)) transpose_fragments[TilesPerWarp] = { + make_fragment_like(tmp_tCsB), + make_fragment_like(tmp_tCsB) + }; + + CUTLASS_PRAGMA_NO_UNROLL + for (int warp_group_tile = 0; warp_group_tile < WarpgroupTileNum; ++warp_group_tile) { + int tmp_warp_tile_n_coord_LUT = current_warp_tile_n_coord_LUT; + int tmp_warp_tile_k_coord_LUT = current_warp_tile_k_coord_LUT; + constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup; + + if constexpr (NumMathWarpGroup == 2) { + tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2)); + tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2)); + } + + CUTLASS_PRAGMA_NO_UNROLL + for (int step_per_warp_group = 0; step_per_warp_group < StepsPerWarpGroup; ++step_per_warp_group) { + // For 2 math warpgroup, warp idx4~7 is 1st warp group and 8~9 is 2nd, so decide if 2nd warpgroup need warp idx divide 8. + int step = step_per_warp_group * NumMathWarpGroup + warp_idx / (NumWarpsPerWarpGroup * 2); + // decoding the warp tile coord. 
+ int warp_tile0_n = step < NumStepsEncoded ? (tmp_warp_tile_n_coord_LUT & MaskPerStep) : 4 + warp_idx_in_warp_group; + int warp_tile0_k = step < NumStepsEncoded ? (tmp_warp_tile_k_coord_LUT & MaskPerStep) : step - 4; + int warp_tile1_n = warp_tile0_n == warp_tile0_k ? warp_tile0_n + 1 : warp_tile0_k; + int warp_tile1_k = warp_tile0_n == warp_tile0_k ? warp_tile0_k + 1 : warp_tile0_n; + + tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep; + tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep; + + static_assert(TilesPerWarp == 2); + + // [warp_tile][n/k] + const int warp_tile_coord[TilesPerWarp][2] = { + // n k + {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile0_n, warp_tile0_k}, // warp_tile 0 + {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile1_n, warp_tile1_k} // warp_tile 1 + }; + + CUTLASS_PRAGMA_UNROLL + for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) { + Tensor tCsB = sB_thr_copy.partition_S( + flatten(s_tile(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) + ); // (CPY, CPY_N, CPY_K) + + copy(sB_tiled_copy, tCsB, transpose_fragments[warp_tile]); + } + + // Make sure elements in two 8x8 warp tiles are all consumed + __syncwarp(); + + CUTLASS_PRAGMA_UNROLL + for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) { + Tensor tCsB_transposed = sB_thr_copy.partition_D( + flatten(s_tile_transposed(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) + ); // (CPY, CPY_N, CPY_K) + copy(sB_tiled_copy, transpose_fragments[warp_tile], tCsB_transposed); + } + } // lock step + } // loop warp_group_tile + } + + CUTLASS_DEVICE void synchronize(int step) { + if (step == 0) { + // SMEM fence to make sure B is transposed before math + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); + } + } + + CUTLASS_DEVICE void synchronize() { + cutlass::arch::fence_view_async_shared(); + 
cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::TransposeBarrier); + } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void transpose( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + int read_stage) { + this->operator()(sB, gmma_sB, read_stage, 0); + synchronize(); + } + +private: + const int warp_idx; + const int warp_group_thread_idx; + const int warp_idx_in_warp_group; + const int current_warp_tile_n_coord_LUT; + const int current_warp_tile_k_coord_LUT; +}; + + +template< + class TiledMma, + class SmemLayoutB, + class SmemLayoutAtomB, + class ElementB, + bool TransposeB +> +constexpr CUTLASS_HOST_DEVICE +auto +make_transpose_operand_b( + int warp_idx, + int warp_group_thread_idx, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, + ElementB, + cute::bool_constant) +{ + if constexpr (!TransposeB) { + return NoTranspositionOperandB( + warp_idx, warp_group_thread_idx, TiledMma{}, + SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); + } + else if constexpr (use_universal_transposition()) { + return UniversalTranspositionOperandB( + warp_idx, warp_group_thread_idx, TiledMma{}, + SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); + } + else if constexpr (sizeof(ElementB) == 1) { + return AsyncTranspositionOperandB_1BElementB( + warp_idx, warp_group_thread_idx, TiledMma{}, + SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); + } + else { + return AsyncTranspositionOperandB( + warp_idx, warp_group_thread_idx, TiledMma{}, + SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); + } +} + +}; // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace transform +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/device/transform_universal_adapter.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/device/transform_universal_adapter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..265d2fe4367180b0c5c76f22df7d00f01dfb170e --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/device/transform_universal_adapter.hpp @@ -0,0 +1,303 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Transform Kernel Universal adapter +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/detail/mma.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cutlass/kernel_launch.h" +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::transform::device { + +//////////////////////////////////////////////////////////////////////////////// + +template +class TransformUniversalAdapter +{ +public: + using TransformKernel = GetUnderlyingKernel_t; + using Arguments = typename TransformKernel::Arguments; + using Params = typename TransformKernel::Params; + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + +private: + + /// Kernel API parameters object + Params params_; + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status + can_implement(Arguments const& args) { + return TransformKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += TransformKernel::get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Arguments const& args, void* workspace = nullptr) { + auto tmp_params = TransformKernel::to_underlying_arguments(args, workspace); + return TransformKernel::get_grid_shape(tmp_params); + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return TransformKernel::get_grid_shape(params); + } + + + /// Initializes GEMM state from arguments. + Status + initialize( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + + CUTLASS_TRACE_HOST("TransformUniversalAdapter::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null") + << ", EnableCudaHostAdapter: " << (kEnableCudaHostAdapter ? "True" : "false")); + + // Initialize the workspace + Status status = TransformKernel::initialize_workspace(args, workspace, stream, cuda_adapter); + if (status != Status::kSuccess) { + return status; + } + // Initialize the Params structure + params_ = TransformKernel::to_underlying_arguments(args, workspace); + // Don't set the function attributes - require the CudaHostAdapter to set it. 
+ if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + return Status::kSuccess; + } + else { + // + // Account for dynamic smem capacity if needed + // + int smem_size = TransformKernel::SharedStorageSize; + + CUTLASS_ASSERT(cuda_adapter == nullptr); + + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + } + return Status::kSuccess; + } + + static Status + run(Params& params, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + int32_t kernel_index = 0, + bool launch_with_pdl = false) { + CUTLASS_TRACE_HOST("TransformUniversalAdapter::run()"); + dim3 const block = TransformKernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = TransformKernel::SharedStorageSize; + + Status launch_result{ Status::kSuccess }; + // Use extended launch API only for mainloops that use it + if constexpr (TransformKernel::ArchTag::kMinComputeCapability >= 90) { + // Currently only support 1x1x1 for transform kernel. 
+ dim3 const cluster = {1,1,1}; + void* kernel_params[] = {¶ms}; + + if constexpr (kEnableCudaHostAdapter) { + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + + if (launch_with_pdl) { + CUTLASS_TRACE_HOST( + "TransformUniversalAdapter::run() does not support launching with PDL and a custom cuda adapter."); + return Status::kErrorInternal; + } + launch_result = cuda_adapter->launch(grid, + cluster, + block, + smem_size, + stream, + kernel_params, + kernel_index); + CUTLASS_TRACE_HOST("Kernel Launch Result" << cutlassGetStatusString(launch_result)); + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + void const* kernel = (void const*) device_kernel; + if constexpr (TransformKernel::ArchTag::kMinComputeCapability == 90) { + launch_result = ClusterLauncher::launch( + grid, cluster, block, smem_size, stream, kernel, kernel_params, launch_with_pdl); + } + } + } + else { + launch_result = Status::kSuccess; + cutlass::arch::synclog_setup(); + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + void* kernel_params[] = {¶ms}; + + launch_result = cuda_adapter->launch( + grid, block, smem_size, stream, kernel_params, 0 + ); + + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + cutlass::kernel_launch(grid, block, smem_size, stream, params, launch_with_pdl); + } + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } + else if (cudaSuccess != result) { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << cudaGetErrorString(result)); + } + else if (Status::kSuccess != launch_result) { + CUTLASS_TRACE_HOST(" Kernel launch failed. 
Reason: " << cutlassGetStatusString(launch_result)); + } + return Status::kErrorInternal; + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + int32_t kernel_index = 0, + bool launch_with_pdl = false + ) { + Status status = initialize(args, workspace, stream, cuda_adapter); + + if (Status::kSuccess == status) { + status = run(params_, stream, cuda_adapter, kernel_index, launch_with_pdl); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(args, workspace, stream, cuda_adapter, 0 /*kernel_index*/, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run( + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, 0 /*kernel_index*/, launch_with_pdl); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. 
+ Status + operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, bool launch_with_pdl = false) { + return run(params_, stream, cuda_adapter, 0 /*kernel_index*/, launch_with_pdl); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::transform::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/filter_format_transformer.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/filter_format_transformer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9c9d7589a309ebe6276bb564ac76a9e036bdd50a --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/filter_format_transformer.hpp @@ -0,0 +1,223 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* \file + \brief Convolution filter format transformation kernel. +*/ + +#pragma once + +#include +#include + +#include "cutlass/coord.h" +#include "cutlass/arch/arch.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cute/int_tuple.hpp" +#include "cute/tensor.hpp" +#include "cute/config.hpp" + +namespace cutlass::transform::kernel { + +using namespace cute; + +enum class FilterFormat { + CKTRS, + CTRSK, + KTRSC +}; + +template < + FilterFormat SrcFormat, + FilterFormat DstFormat, + int NumDimensions, + class Element_, + int AlignmentBytes = 16 +> +struct ConvFilterFormatTransformer { + + using Element = Element_; + static_assert(SrcFormat == FilterFormat::CKTRS, "Currently only source format of CKTRS is supported"); + static_assert(DstFormat == FilterFormat::CTRSK || DstFormat == FilterFormat::KTRSC, "Currently only destination format of CTRSK/KTRSC is supported"); + static_assert(AlignmentBytes > 0 && AlignmentBytes % static_cast(sizeof(Element)) == 0, "Invalid alignment setting"); + + // In ktrsc order. 
+ using FilterExtent = array; + + // Default cta tile shape: 32x32 + static constexpr auto CTATileShape = make_shape(Int<4 * AlignmentBytes / static_cast(sizeof(Element))>{}, Int<32>{}); + // Default thread layout: (4, 32) + static constexpr auto ThreadLayout = make_layout(make_shape(Int<4>{}, Int<32>{})); + + static constexpr uint32_t MaxThreadsPerBlock = 128; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + using ArchTag = arch::Sm90; + + // Default ctor + CUTLASS_HOST_DEVICE + ConvFilterFormatTransformer() {} + + struct Arguments { + const void *src_ptr; + void *dst_ptr; + FilterExtent filter_extent; + }; + + struct Params { + using TensorSrc = decltype(make_tensor(make_gmem_ptr(recast_ptr(nullptr)), make_layout(take<0,NumDimensions>(FilterExtent{})))); + using TensorDst = decltype(make_tensor(make_gmem_ptr(recast_ptr(nullptr)), make_layout(make_shape(int32_t(0), int32_t(0))))); + + TensorSrc src; + TensorDst dst; + }; + + struct SharedStorage { + /* empty, no smem needed */ + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + static Status + can_implement(Arguments const& args) { + bool implementable = true; + // alignment rule + { + int contiguous_dim = DstFormat == FilterFormat::CTRSK ? 
args.filter_extent[0] : args.filter_extent[NumDimensions - 1]; + int align_element = AlignmentBytes / static_cast(sizeof(Element)); + + implementable &= (contiguous_dim % align_element == 0); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Alignment setting is invalid.\n"); + return Status::kInvalid; + } + } + + return Status::kSuccess; + } + + static size_t + get_workspace_size(Arguments const& args) { + return 0; + } + + static dim3 + get_block_shape() { + return dim3(size(shape(ThreadLayout)), 1, 1); + } + + static dim3 + get_grid_shape(Params const& params) { + auto dim_m = ceil_div(size<0>(shape(params.dst)), get<0>(CTATileShape)); + auto dim_n = ceil_div(size<1>(shape(params.dst)), get<1>(CTATileShape)); + + return dim3(dim_m, dim_n, 1); + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + return Status::kSuccess; + } + + static Params + to_underlying_arguments(Arguments const& args, void* workspace) { + auto k = args.filter_extent[0]; + auto c = args.filter_extent[NumDimensions - 1]; + auto srt = reverse(take<1,NumDimensions - 1>(args.filter_extent)); + + // source shape (s,r,t,k,c) + auto shape_src = flatten(make_shape(srt, k, c)); + auto shape_dst = DstFormat == FilterFormat::CTRSK ? 
make_shape(k, c * product(srt)) : make_shape(c, k * product(srt)); + + auto src = make_tensor(make_gmem_ptr(recast_ptr(args.src_ptr)), make_layout(shape_src)); + auto dst = make_tensor(make_gmem_ptr(recast_ptr(args.dst_ptr)), make_layout(shape_dst)); + + return Params{src, dst}; + } + + CUTLASS_DEVICE + void operator()(Params const& params, char *smem_buf) { + // Tile the input tensor into blocks + auto block_coord = make_coord(blockIdx.x, blockIdx.y); + auto block_shape = make_shape(Int<4 * AlignmentBytes / static_cast(sizeof(Element))>{}, Int<32>{}); + // Default thread layout: (4, 32) + auto thread_layout = make_layout(make_shape(Int<4>{}, Int<32>{})); + auto vec_layout = make_layout(make_shape(Int(sizeof(Element))>{}, Int<1>{})); + + Tensor tile_D = local_tile(params.dst, block_shape, block_coord); + + // Construct tiled copy + using AccessType = cutlass::AlignedArray; + using Atom = Copy_Atom, Element>; + + auto tiled_copy = make_tiled_copy(Atom{}, thread_layout, vec_layout); + auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x); + Tensor thr_tile_D = thr_copy.partition_D(tile_D); + + // shape (s, r, t) + auto shape_trs = take<0, NumDimensions - 2>(shape(params.src)); + // strided_c = c for format CTRSK, strided_c = k for format KTRSC + auto strided_c = DstFormat == FilterFormat::CTRSK ? 
get(shape(params.src)) : get(shape(params.src)); + // shape (s, r, t, c) for format CTRSK and shape (s, r, t, k) for format KTRSC + auto shape_ctrs = append(shape_trs, strided_c); + auto srtc_coord = idx2crd(int(blockIdx.y * get<1>(block_shape) + threadIdx.x / size<0>(thread_layout)), shape_ctrs); + // index of k for format CTRSK and index of c for format KTRSC + auto n_layout = make_layout(make_shape(gridDim.x, size<0>(thread_layout)), make_stride(size<0>(block_shape), size<0>(vec_layout))); + int n_idx = n_layout(make_coord(blockIdx.x, threadIdx.x % size<0>(thread_layout))); + + // Fragment to load from S and store to D + auto frag = make_fragment_like(thr_tile_D); + // Predicate tensor. + Tensor thr_tile_P = make_tensor(shape(thr_tile_D)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(frag); ++i) { + auto srt_coord = take<0, NumDimensions - 2>(srtc_coord); + auto kc_coord = DstFormat == FilterFormat::CTRSK ? + make_coord(n_idx+i, get(srtc_coord)) : + make_coord(get(srtc_coord), n_idx+i); + auto coord = flatten(make_coord(srt_coord, kc_coord)); + thr_tile_P(i) = elem_less(coord, shape(params.src)); + if (thr_tile_P(i)) { + frag(i) = params.src(coord); + } + } + + // Copy from RMEM to GMEM + copy_if(tiled_copy, thr_tile_P, frag, thr_tile_D); + } +}; + +} // namespace cutlass::transform::kernel diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..577c68c341c5c7a3d26c7209b2c40e309c65abee --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp @@ -0,0 +1,603 @@ +/*************************************************************************************************** + * Copyright 
(c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Compress utils specific for SM90 structure sparse kernels +*/ + +#pragma once + +#include "cute/container/bit_field.hpp" // cute::bit_field +#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v, cute::uint_bit_t +#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor +#include "cute/algorithm/cooperative_copy.hpp" // cute::cooperative_copy +#include "cutlass/arch/arch.h" // cutlass::arch::Sm90 +#include "cutlass/cuda_host_adapter.hpp" // cutlass::CudaHostAdapter +#include "cutlass/cutlass.h" // cutlass::Status +#include "cutlass/gemm/gemm.h" // cutlass::TagToStrideA_t +#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up +#include "cutlass/kernel_hardware_info.h" // cutlass::KernelHardwareInfo +#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes +#include "cutlass/numeric_types.h" // cutlass::has_negative_zero_v +#include "cutlass/cuda_host_adapter.hpp" // cutlass::CudaHostAdapter + +namespace cutlass::transform::kernel { + +using namespace cute; + +template< + class ProblemShape_, + class ElementA_, + class LayoutATag_, + class SparseConfig_ +> +class SM90StructuredSparseCompressor { +public: + using SparseConfig = SparseConfig_; + using ProblemShape = ProblemShape_; + + // * EltA + using ElementA = ElementA_; + using ElementAUint = cute::uint_bit_t>; + using ElementAMma = typename SparseConfig::ElementAMma; + using ElementAMmaRaw = typename SparseConfig::ElementAMmaRaw; + using ElementAMmaRawUnit = cute::uint_bit_t>; + using ElementASparsity = typename SparseConfig::ElementASparsity; + using ElementAMmaSparsity = typename SparseConfig::ElementAMmaSparsity; + using ElementAUintCompressed = cute::sparse_elem; + using LayoutATag = LayoutATag_; + using LayoutA = LayoutATag; + using StrideA = cutlass::gemm::TagToStrideA_t; + + // * EltE + using ElementEMma = typename SparseConfig::ElementEMma; + using ElementEMmaRaw = typename SparseConfig::ElementEMmaRaw; + using ElementEMmaSparsity = typename 
SparseConfig::ElementEMmaSparsity; + // Data Type for storing one chunk's metadata + static constexpr int ElementEBitsPerChunk = typename SparseConfig::ElementEBitsPerChunk{}; + CUTE_STATIC_ASSERT(ElementEBitsPerChunk == 4, "ElementEBitsPerChunk is 4 for SM90"); + using ElementEChunk = cute::uint_bit_t; + CUTE_STATIC_ASSERT(cute::is_same_v, "ElementEChunk is uint4_t for SM90"); + using ElementESparsityPerChunk = Int / ElementEBitsPerChunk)>; + + // AtomE + using TensorEAtom = typename SparseConfig::TensorEAtom; + using TensorEAtomK = typename SparseConfig::TensorEAtomK; + using TensorEAtomM = typename SparseConfig::TensorEAtomM; + + static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{}; + static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; + static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{}; + static constexpr int LogicalElemsAMmaRawPerChunk = cutlass::ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + static constexpr int PhysicalElemsAMmaRawPerChunk = cutlass::ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + + // * Alignment + static constexpr int TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{}; + static constexpr int TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{}; + static constexpr int TensorAAlignmentK = typename SparseConfig::TensorAAlignmentK{}; + static constexpr int TensorAAlignmentM = typename SparseConfig::TensorAAlignmentM{}; + + // Required by `device_kernel` + static constexpr int MaxThreadsPerBlock = TensorEAtomM{}; + static constexpr int MinBlocksPerMultiprocessor = 1; + using ArchTag = arch::Sm90; + + struct SharedStorage { + ElementEMma cEsE[cute::size(TensorEAtom{})]; + ElementAUintCompressed cACsAC[cute::size(TensorEAtom{})]; + ElementAUint cAsA[cute::size(TensorEAtom{})]; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + struct 
TransformArguments { + void const* ptr_A{nullptr}; + StrideA dA{}; + void* ptr_ACompress{nullptr}; + void* ptr_E{nullptr}; + }; + + using TransformParams = TransformArguments; + + struct Arguments { + ProblemShape problem_shape{}; + TransformArguments transform{}; + KernelHardwareInfo hw_info{}; + }; + + struct Params { + ProblemShape problem_shape{}; + TransformParams transform{}; + KernelHardwareInfo hw_info{}; + void* workspace = nullptr; + }; + +public: + static Params + to_underlying_arguments(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::to_underlying_arguments()"); + return Params{{args.problem_shape}, + {args.transform.ptr_A, args.transform.dA, args.transform.ptr_ACompress, args.transform.ptr_E}, + {args.hw_info}, + workspace}; + } + + static Status + can_implement(Arguments const& args) { + auto [M, N, K, L] = args.problem_shape; + if (K % LogicalElemsAPerChunk != 0) { + CUTLASS_TRACE_HOST("SM90 Sparse Compressor CAN NOT IMPLEMENT: GemmK not multiplier of logical chunk size"); + return Status::kErrorInvalidProblem; + } + CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::can_implement() (True)"); + return Status::kSuccess; + } + + static size_t + get_workspace_size(Arguments const& args) { + CUTLASS_UNUSED(args); + // Backward compatible with host compressor + CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::get_workspace_size() (" << SharedStorageSize << ")"); + return SharedStorageSize; + } + + static Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + CUTLASS_UNUSED(args); + CUTLASS_UNUSED(workspace); + CUTLASS_UNUSED(stream); + CUTLASS_UNUSED(cuda_adapter); + CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::initialize_workspace()"); + return Status::kSuccess; + } + + static dim3 + get_grid_shape(Params const& params) { + constexpr int MaxAlignmentM = 
cutlass::const_max(TensorEAlignmentM, TensorAAlignmentM); + constexpr int MaxAlignmentK = cutlass::const_max(TensorEAlignmentK, TensorAAlignmentK); + const auto [GemmM, GemmN, GemmK, GemmL] = params.problem_shape; + + const int GemmMAlignedMax = cutlass::round_up(GemmM, MaxAlignmentM); + const int GemmKAlignedMax = cutlass::round_up(GemmK, MaxAlignmentK); + + const int gridDim_X = cutlass::ceil_div(GemmMAlignedMax, TensorEAtomM{}); + const int gridDim_Y = cutlass::ceil_div(GemmKAlignedMax, TensorEAtomK{}); + const int gridDim_Z = GemmL; + + CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::get_grid_shape() (" + << gridDim_X << ", " + << gridDim_Y << ", " + << gridDim_Z << ")"); + return dim3(gridDim_X, gridDim_Y, gridDim_Z); + } + + static dim3 + get_block_shape() { + CUTLASS_TRACE_HOST("SM90StructuredSparseCompressor::get_block_shape() (" + << MaxThreadsPerBlock << ", " + << 1 << ", " + << 1 << ")"); + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTE_DEVICE + void + operator()(Params params, void* smem_buf = nullptr) { + run(params, smem_buf); + } + + CUTE_DEVICE + static void + run(Params params, void* smem_buf = nullptr) { + structure_sparse_compress(params, smem_buf); + } + +private: + + struct MetadataOneChunk1to2 { + + CUTE_DEVICE + void set_metadata_bits(int elt_log_idx, int elt_phy_idx) { + auto metadata_bits = [&]() -> uint8_t { + CUTLASS_ASSERT(elt_log_idx >= 0 && elt_log_idx < 2); + switch (elt_log_idx) { + case 0: + return 0b0100; + case 1: + return 0b1110; + default: + CUTE_GCC_UNREACHABLE; + } + }; + + storage_ |= (metadata_bits() << (4 * elt_phy_idx)); + } + + + CUTE_DEVICE + ElementEChunk storage() const { + return ElementEChunk{storage_}; + } + + private: + uint8_t storage_ = 0b0000; + }; + + struct MetadataOneChunk2to4{ + + CUTE_DEVICE + void set_metadata_bits(int elt_log_idx, int elt_phy_idx) { + auto metadata_bits = [&]() -> uint8_t { + CUTLASS_ASSERT(elt_log_idx >= 0 && elt_log_idx < 4); + switch (elt_log_idx) { + case 0: + return 0b00; + 
case 1: + return 0b01; + case 2: + return 0b10; + case 3: + return 0b11; + default: + CUTLASS_ASSERT(false); + CUTE_GCC_UNREACHABLE; + return 0b00; + } + }; + + storage_ |= (metadata_bits() << (2 * elt_phy_idx)); + } + + CUTE_DEVICE + ElementEChunk storage() const { + return ElementEChunk{storage_}; + } + + private: + uint8_t storage_ = 0b0000; + }; + + using MetadataOneChunk = cute::conditional_t; + +private: + + CUTE_DEVICE + static void + structure_sparse_compress(Params params, void* smem_buf) { + // * Input Params + auto [GemmM, GemmN, GemmK, GemmL] = params.problem_shape; + auto [ptr_A, dA, ptr_ACompress, ptr_E] = params.transform; + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + [[maybe_unused]] const int gridDim_X = gridDim.x; + [[maybe_unused]] const int gridDim_Y = gridDim.y; + [[maybe_unused]] const int gridDim_Z = gridDim.z; + [[maybe_unused]] const int blockDim_X = blockDim.x; + + // * Global Tensor Layout + const cute::Layout layout_gA = make_layout(make_shape(GemmM, GemmK, GemmL), dA); + const cute::Layout layout_gAC = SparseConfig::fill_layoutA(params.problem_shape); + const cute::Layout layout_gE = SparseConfig::fill_layoutE(params.problem_shape); + + // * Construct Global Tensor + const cute::Tensor gA = make_tensor(make_gmem_ptr(cute::recast_ptr(ptr_A)), layout_gA); + cute::Tensor gAC_sparse = make_tensor(make_gmem_ptr(cute::recast_ptr(ptr_ACompress)), layout_gAC ); + cute::Tensor gAC = cute::recast(gAC_sparse); + cute::Tensor gE_sparse = make_tensor(make_gmem_ptr(cute::recast_ptr(ptr_E)), layout_gE); + cute::Tensor gE = cute::recast(gE_sparse); + + // * CTA Tensor Layout + using cAsA_layout_row = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{}), LayoutRight{})); + using cAsA_layout_col = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{}), LayoutLeft{})); + using cAsA_layout = cute::conditional_t, cAsA_layout_row, cAsA_layout_col>; + using cACsAC_layout = 
decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{} / ElementASparsity{}), LayoutRight{})); + using cEsE_layout = decltype(make_layout(make_shape(TensorEAtomM{}, TensorEAtomK{} / ElementEMmaSparsity{}), LayoutRight{})); + + CUTE_STATIC_ASSERT(cute::is_static_v, "TensorEAtom needs to be static"); + CUTE_STATIC_ASSERT(cute::is_static_v, "cAsA_layout needs to be static"); + CUTE_STATIC_ASSERT(cute::is_static_v, "cACsAC_layout needs to be static"); + CUTE_STATIC_ASSERT(cute::is_static_v, "cEsE_layout needs to be static"); + + const int blockIdx_X = blockIdx.x; + const int blockIdx_Y = blockIdx.y; + const int blockIdx_Z = blockIdx.z; + const int threadIdx_X = threadIdx.x; + + // * Construct CTA Tensor + const auto cta_coord = make_coord(blockIdx_X, blockIdx_Y, blockIdx_Z); + cute::Tensor cAgA = cute::recast(local_tile(gA, shape(cAsA_layout{}), cta_coord)); + cute::Tensor cACgAC = cute::recast(local_tile(gAC, shape(cACsAC_layout{}), cta_coord)); + cute::Tensor cEgE = local_tile(gE, shape(cEsE_layout{}), cta_coord); + + cute::Tensor cAsA = cute::recast(make_tensor(make_smem_ptr(cute::recast_ptr(shared_storage.cAsA)), cAsA_layout{})); + cute::Tensor cACsAC = cute::recast(make_tensor(make_smem_ptr(cute::recast_ptr(shared_storage.cACsAC)), cACsAC_layout{})); + cute::Tensor cEsE = make_tensor(make_smem_ptr(cute::recast_ptr(shared_storage.cEsE)), cEsE_layout{}); + cute::Tensor cEsE_chunk = cute::recast(cEsE); + + // * Handle in unit of Chunk when compress + using OneChunkSizeA = Int; + using OneChunkSizeAC = Int; + using OneChunkSizeE = Int; + using NumOneChunkK = Int; + + cute::Tensor cAsA_log_chunk = logical_divide(cAsA, make_shape(_, OneChunkSizeA{})); + cute::Tensor cACsAC_log_chunk = logical_divide(cACsAC, make_shape(_, OneChunkSizeAC{})); + cute::Tensor cEsE_log_chunk = logical_divide(cEsE_chunk, make_shape(_, OneChunkSizeE{})); + + // * Corner Case Handle + const auto GemmM_within_Cta = (GemmM - blockIdx_X * TensorEAtomM{} > TensorEAtomM{}) ? 
TensorEAtomM{} : GemmM - blockIdx_X * TensorEAtomM{}; + const auto GemmK_within_Cta = ( (GemmK - blockIdx_Y * TensorEAtomK{} > TensorEAtomK{}) ? TensorEAtomK{} : GemmK - blockIdx_Y * TensorEAtomK{} ) / ElemsARawPerElementAMmaRaw; + const auto GemmK_NumOneChunk_within_Cta = GemmK_within_Cta / LogicalElemsAMmaRawPerChunk; + + const auto GemmMAlignedAC = cutlass::round_up(GemmM, TensorAAlignmentM); + const auto GemmKAlignedAC = cutlass::round_up(GemmK, TensorAAlignmentK); + const auto GemmMAlignedAC_within_Cta = (GemmMAlignedAC - blockIdx_X * TensorEAtomM{} > TensorEAtomM{}) ? TensorEAtomM{} : GemmMAlignedAC - blockIdx_X * TensorEAtomM{}; + const auto GemmKAlignedAC_within_Cta = ( (GemmKAlignedAC - blockIdx_Y * TensorEAtomK{} > TensorEAtomK{}) ? TensorEAtomK{} : GemmKAlignedAC - blockIdx_Y * TensorEAtomK{} ) / ElemsARawPerElementAMmaRaw; + + // * Clear CTA Smem Tensor + cooperative_clear(threadIdx_X, cACsAC); + cooperative_clear(threadIdx_X, cEsE); + + // * Input CTA Tensor G to S + if (GemmM_within_Cta == TensorEAtomM{} && GemmK_within_Cta == TensorEAtomK{}) { + copy_vec_pred(cAgA, cAsA, threadIdx_X, GemmM_within_Cta, GemmK_within_Cta); + } + else { + copy_vec_pred(cAgA, cAsA, threadIdx_X, GemmM_within_Cta, GemmK_within_Cta); + } + + // Construct a sign bit mask for handling negative zeros + ElementAMmaRawUnit sign_mask = ElementAMmaRawUnit{ 0 }; + if constexpr (has_negative_zero_v) { + ElementAMmaRawUnit one_sign_mask = static_cast(~(ElementAMmaRawUnit{ 1 } << (cute::sizeof_bits_v - 1))); + for (int i = 0; i < sizeof(ElementAMmaRawUnit) / sizeof(ElementAUint); ++i) { + sign_mask = static_cast((int32_t)sign_mask | (int32_t)one_sign_mask << (i * cute::sizeof_bits_v)); + } + } + + // * Compress + // cACsAC is always row major order + // TensorEAtomM threads perform the compression, each thread compress one row + const int row_i = threadIdx_X; + if (row_i < GemmM_within_Cta) { + + CUTE_UNROLL + for (int col_chunk_i = 0; col_chunk_i < NumOneChunkK{}; ++col_chunk_i) { + 
if (col_chunk_i < GemmK_NumOneChunk_within_Cta) { + // Compress is handled in unit of ElementAMmaRawUnit + cute::Tensor tAsA = cAsA_log_chunk(row_i, make_coord(_, col_chunk_i)); + cute::Tensor tACsAC = cACsAC_log_chunk(row_i, make_coord(_, col_chunk_i)); + cute::Tensor tEsE = cEsE_log_chunk(row_i, make_coord(_, col_chunk_i)); + + int non_zero_cnt = 0; + // None zero element indx + // e.g. + // 2:4 sparsity [x 0 0 x] + // non_zero_elt_log_idx = [0, 3] + int non_zero_elt_log_idx[OneChunkSizeAC{}] = { 0 }; + + // * Find None Zero Element Idx within Chunk + CUTE_UNROLL + for (int elt_log_idx = 0; elt_log_idx < OneChunkSizeA{}; ++elt_log_idx) { + ElementAMmaRawUnit elem_A = tAsA[elt_log_idx]; + + // Handle negative 0 + ElementAMmaRawUnit masked_elem_A = elem_A; + if constexpr (has_negative_zero_v) { + masked_elem_A = elem_A & sign_mask; + } + + if (masked_elem_A != ElementAMmaRawUnit{0}) { + non_zero_elt_log_idx[non_zero_cnt] = elt_log_idx; + tACsAC[non_zero_cnt] = elem_A; + non_zero_cnt++; + } + } + + // * Corner Case for 2:4 sparsity + if constexpr (cute::sizeof_bits_v < 32) { + // i.e. [0 0 0 x] -> [(0) 0 0 x] + if (non_zero_cnt == 1 && non_zero_elt_log_idx[0] == 3) { + tACsAC[1] = tACsAC[0]; + tACsAC[0] = ElementAMmaRawUnit{0}; + non_zero_elt_log_idx[0] = 0; + non_zero_elt_log_idx[1] = 3; + } + // i.e. [0 0 x 0] -> [0 0 x (0)] + // i.e. [0 x 0 0] -> [0 x 0 (0)] + // i.e. 
[x 0 0 0] -> [x 0 0 (0)] + else if (non_zero_cnt == 1) { + tACsAC[1] = ElementAMmaRawUnit{0}; + non_zero_elt_log_idx[1] = 3; + } + } + + // * Set Metadata Bits + MetadataOneChunk metadata_one_chunk; + CUTE_UNROLL + for (int elt_phy_idx = 0; elt_phy_idx < OneChunkSizeAC{}; elt_phy_idx++) { + metadata_one_chunk.set_metadata_bits(non_zero_elt_log_idx[elt_phy_idx], elt_phy_idx); + } + tEsE[0] = metadata_one_chunk.storage(); + + } + else { + break; + } + } + } + + // * Sync after Compress + __syncthreads(); + + // * Output Cta Tensor S to G + if (GemmM_within_Cta > 0 && GemmK_within_Cta > 0) { + constexpr int MaxVecBits = 128; // STG.128 + cute::cooperative_copy(threadIdx_X, cEsE, cEgE); + } + + if (GemmMAlignedAC_within_Cta == TensorEAtomM{} && GemmKAlignedAC_within_Cta == TensorEAtomK{}) { + copy_vec_pred(cACsAC, cACgAC, threadIdx_X, GemmMAlignedAC_within_Cta, (GemmKAlignedAC_within_Cta / ElementASparsity::value)); + } + else { + copy_vec_pred(cACsAC, cACgAC, threadIdx_X, GemmMAlignedAC_within_Cta, (GemmKAlignedAC_within_Cta / ElementASparsity::value)); + } + + } // end of structure_sparse_compress() + + template + CUTE_DEVICE + static void + cooperative_clear( + uint32_t const& tid, + TensorSrc dSrc) { + + auto dSrctSrc = local_partition(dSrc, make_layout(make_shape(NumThreads, _1{})), tid); + cute::clear(dSrctSrc); + + // Sync all thread data access + __syncthreads(); + } + + template + CUTE_DEVICE + static void + copy_vec_pred( + TensorSrc dSrc, + TensorDst dDst, + int threadIdx_X, + int valid_rows, + int valid_cols) { + + constexpr bool IsRowMajor = cute::is_same_v; + using Element = typename TensorSrc::element_type; + constexpr bool IsQmmaF6 = cute::sizeof_bits_v == 6; + + CUTE_STATIC_ASSERT(cute::is_static_v, "shape(dSrc) needs to be static"); + CUTE_STATIC_ASSERT(cute::is_static_v, "shape(dDst) needs to be static"); + CUTE_STATIC_ASSERT(cute::sizeof_bits_v == cute::sizeof_bits_v, + "dSrc and dDst need to have same element bit width"); + 
CUTE_STATIC_ASSERT(cute::size(dSrc) == cute::size(dDst), "dSrc and dDst need to have same size"); + + // ValueShape + using ValueShape = + cute::conditional_t, Int<1>>, + cute::conditional_t, Int<128 / sizeof_bits_v>>, + Shape>, Int<1>>> + >; + + constexpr int ValueShapeRows = shape<0>(ValueShape{}); + constexpr int ValueShapeCols = shape<1>(ValueShape{}); + + // ThreadShape + using ThreadShape = + cute::conditional_t, Int<1>>, + Shape, Int>>, + cute::conditional_t(dSrc) / ValueShapeCols)>, Int< (shape<1>(dSrc) / ValueShapeCols)>>, + Shape(dSrc) / ValueShapeRows)>, Int(dSrc) / ValueShapeRows)>>> + >; + + constexpr int ThreadShapeRows = shape<0>(ThreadShape{}); + constexpr int ThreadShapeCols = shape<1>(ThreadShape{}); + + const int threadIdx_X_row = threadIdx_X / ThreadShapeCols; + const int threadIdx_X_col = threadIdx_X % ThreadShapeCols; + + // Row Major + if constexpr (IsRowMajor) { + CUTE_UNROLL + for (int iter_row_blk = 0; iter_row_blk < cutlass::ceil_div(shape<0>(dSrc), ThreadShapeRows * ValueShapeRows); ++iter_row_blk) { + CUTE_UNROLL + for (int col_chunk_i = 0; col_chunk_i < cutlass::ceil_div(shape<1>(dSrc) , ThreadShapeCols * ValueShapeCols); ++col_chunk_i) { + CUTE_UNROLL + for (int iter_row_thr = 0; iter_row_thr < ValueShapeRows; ++iter_row_thr) { + CUTE_UNROLL + for (int iter_col_thr = 0; iter_col_thr < ValueShapeCols; ++iter_col_thr) { + const int row_i = (iter_row_blk * ThreadShapeRows + threadIdx_X_row) * ValueShapeRows + iter_row_thr; + const int col_i = (col_chunk_i * ThreadShapeCols + threadIdx_X_col) * ValueShapeCols + iter_col_thr; + if constexpr ( (not pred) and (not IsQmmaF6) ) { + dDst(row_i, col_i) = dSrc(row_i, col_i); + } + else { + if (row_i < valid_rows && col_i < valid_cols) { + dDst(row_i, col_i) = dSrc(row_i, col_i); + } + } + } + } + } + } + } + // Col Major + else { + CUTE_UNROLL + for (int col_chunk_i = 0; col_chunk_i < cutlass::ceil_div(shape<1>(dSrc) , ThreadShapeCols * ValueShapeCols); ++col_chunk_i) { + CUTE_UNROLL + for (int 
iter_row_blk = 0; iter_row_blk < cutlass::ceil_div(shape<0>(dSrc), ThreadShapeRows * ValueShapeRows); ++iter_row_blk) { + CUTE_UNROLL + for (int iter_col_thr = 0; iter_col_thr < ValueShapeCols; ++iter_col_thr) { + CUTE_UNROLL + for (int iter_row_thr = 0; iter_row_thr < ValueShapeRows; ++iter_row_thr) { + const int row_i = (iter_row_blk * ThreadShapeRows + threadIdx_X_row) * ValueShapeRows + iter_row_thr; + const int col_i = (col_chunk_i * ThreadShapeCols + threadIdx_X_col) * ValueShapeCols + iter_col_thr; + if constexpr ( (not pred) and (not IsQmmaF6) ) { + dDst(row_i, col_i) = dSrc(row_i, col_i); + } + else { + if (row_i < valid_rows && col_i < valid_cols) { + dDst(row_i, col_i) = dSrc(row_i, col_i); + } + } + } + } + } + } + } + + // Sync all thread data access + __syncthreads(); + } // end of copy_vec_pred() + +}; + +} // namespace cutlass::transform::kernel diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sparse_gemm_compressor.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sparse_gemm_compressor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9f23535fea5df8df728b7c806d65f75f28c36aa3 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/kernel/sparse_gemm_compressor.hpp @@ -0,0 +1,325 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Compress utils for structured sparse kernels +*/ + +#pragma once + +#include // std::fill +#include // std::array +#include // std::mt19937 + +#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v +#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor +#include "cutlass/arch/arch.h" // cutlass::arch::SmXY +#include "cutlass/detail/dependent_false.hpp" // cutlass::detail::dependent_false +#include "cutlass/gemm/gemm.h" // cutlass::TagToStrideA_t +#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up +#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes + +#include "cutlass/transform/kernel/sm90_sparse_gemm_compressor.hpp" + +namespace cutlass::transform::kernel { + +template< + class ProblemShape_, + class ElementA_, + class LayoutATag_, + class SparseConfig_ +> +class StructuredSparseCompressorUtility { +public: + using SparseConfig = SparseConfig_; + using ProblemShape = ProblemShape_; + + //* EltA + using ElementA = ElementA_; + using LayoutATag = LayoutATag_; + using StrideA = cutlass::gemm::TagToStrideA_t; + using ElementAMmaRaw = typename SparseConfig::ElementAMmaRaw; + using ElementASparsity = typename SparseConfig::ElementASparsity; + using ElementAMmaSparsity = typename SparseConfig::ElementAMmaSparsity; + + //* EltE + using ElementEMmaRaw = typename SparseConfig::ElementEMmaRaw; + using ElementEMmaSparsity = typename SparseConfig::ElementEMmaSparsity; + + //* AtomE + using TensorEAtom = typename SparseConfig::TensorEAtom; + using TensorEAtomK = typename SparseConfig::TensorEAtomK; + using TensorEAtomM = typename SparseConfig::TensorEAtomM; + + static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{}; + static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; + static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{}; + static constexpr int LogicalElemsAMmaRawPerChunk = 
cutlass::ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + static constexpr int PhysicalElemsAMmaRawPerChunk = cutlass::ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + + //* Alignment + static constexpr int TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{}; + static constexpr int TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{}; + static constexpr int TensorAAlignmentK = typename SparseConfig::TensorAAlignmentK{}; + static constexpr int TensorAAlignmentM = typename SparseConfig::TensorAAlignmentM{}; + + StructuredSparseCompressorUtility() = default; + + StructuredSparseCompressorUtility(ProblemShape problem, StrideA dA) { + set_problem_size(problem, dA); + } + + void set_problem_size(ProblemShape problem, StrideA dA_) { + M = cute::size<0>(problem); + K = cute::size<2>(problem); + L = cute::size<3>(problem); + + // The following three vars are logical elem count! + K_alignedA = round_up(K, TensorAAlignmentK); + M_alignedA = round_up(M, TensorAAlignmentM); + K_alignedE = round_up(K, TensorEAlignmentK); + M_alignedE = round_up(M, TensorEAlignmentM); + + dA = dA_; + } + + /** + * @brief Get the TensorE number of ElementE along K after alignment requirement + * + * @return int : number of ElementE (uint8_t) along K-dim + */ + int get_metadata_m_physical() const { + return M_alignedE; + } + + /** + * @brief Get the TensorE number of ElementE along M after alignment requirement + * + * @return int : number of ElementE (uint8_t) along M-dim + */ + int get_metadata_k_physical() const { + return K_alignedE / ElementEMmaSparsity{}; + } + + /** + * @brief Get the TensorACompressed number of ElementA along K after alignment requirement + * + * @return int : number of ElementA along K-dim + */ + int get_tensorA_k_physical() const { + return K_alignedA / ElementASparsity{}; + } + + /** + * @brief Get the TensorACompressed number of ElementA along M after alignment requirement + * + * @return int : number of ElementA along 
M-dim + */ + int get_tensorA_m_physical() const { + return M_alignedA; + } + + /** + * @brief Get the TensorACompressed Bytes + * + * @return uint64_t bytes + */ + uint64_t get_compressed_tensor_A_bytes() const { + const auto tensor_a_comp_num_elt_a = get_tensorA_m_physical() * get_tensorA_k_physical() * L; + const auto tensor_a_comp_bytes = cutlass::bits_to_bytes(tensor_a_comp_num_elt_a * cute::sizeof_bits_v); + return tensor_a_comp_bytes; + } + + /** + * @brief Get the TensorA Bytes + * + * @return uint64_t bytes + */ + uint64_t get_raw_tensor_A_bytes() const { + const auto tensor_a_num_elt_a = uint64_t(M) * uint64_t(K) * uint64_t(L); + const auto tensor_a_bytes = cutlass::bits_to_bytes(tensor_a_num_elt_a * cute::sizeof_bits_v); + return tensor_a_bytes; + } + + /** + * @brief Get the TensorE Bytes + * + * @return uint64_t bytes + */ + uint64_t get_tensor_E_bytes() const { + const auto tensor_e_num_elt_a = uint64_t(get_metadata_m_physical()) * uint64_t(get_metadata_k_physical()) * uint64_t(L); + const auto tensor_e_bytes = cutlass::bits_to_bytes(tensor_e_num_elt_a * cute::sizeof_bits_v); + return tensor_e_bytes; + } + + constexpr auto fill_layoutA_from_compressor() const { + return SparseConfig::fill_layoutA(cute::make_tuple(M,_1{},K,L)); + } + + constexpr auto fill_layoutE_from_compressor() const { + return SparseConfig::fill_layoutE(cute::make_tuple(M,_1{},K,L)); + } + + void structure_sparse_zero_mask_fill(void* host_a_ptr, uint64_t seed) { + + constexpr int ChunkSize = LogicalElemsAMmaRawPerChunk; + using ChunkElement = cute::uint_bit_t>; + + cute::Tensor gA_eltA = cute::make_tensor( + cute::recast_ptr(host_a_ptr), + cute::make_layout(make_shape(M, K, L), dA)); + + // Input TensorA is handled in unit of ElementAMmaRaw instead of ElementA + cute::Tensor gA = cute::recast(gA_eltA); + + // Extract out the Chunk from K-mode + Tensor gA_chunk = cute::zipped_divide(gA, cute::Shape<_1,cute::Int>{}); // (Chunk, Rest) + + // Half of the data is zero to indicate 
sparsityA = 2 + std::array nnzb_indicator{}; + for (size_t i = 1; i < nnzb_indicator.size(); i += 2) { + nnzb_indicator.at(i) = 1; + } + + std::mt19937 rng(seed); + auto rest_shape = cute::shape<1>(gA_chunk); + for (auto iter = cute::make_coord_iterator(rest_shape); iter != cute::ForwardCoordIteratorSentinel{}; ++iter) { + std::shuffle(nnzb_indicator.begin(), nnzb_indicator.end(), rng); + for (int c = 0; c < size<0>(gA_chunk); ++c) { // for each elem within chunk + if (nnzb_indicator[c] == 0) { + gA_chunk(c, *iter) = ChunkElement{0}; + } + } // end of within chunk + } // end of chunk_idx + } + + int M{-1}; + int K{-1}; + int L{-1}; + StrideA dA{}; + +private: + int K_alignedA{-1}; + int M_alignedA{-1}; + int K_alignedE{-1}; + int M_alignedE{-1}; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class ElementA, + class LayoutATag, + class SparseConfig, + class ArchTag +> +struct StructuredSparseCompressorSelector { + static_assert(cutlass::detail::dependent_false, + "Could not select a structured sparse compressor for given parameters."); +}; + +template< + class ProblemShape, + class ElementA, + class LayoutATag, + class SparseConfig +> +struct StructuredSparseCompressorSelector< + ProblemShape, + ElementA, + LayoutATag, + SparseConfig, + arch::Sm90> { + using Compressor = SM90StructuredSparseCompressor< + ProblemShape, + ElementA, + LayoutATag, + SparseConfig + >; +}; + +template< + class ProblemShape, + class ElementA, + class LayoutATag, + class SparseConfig +> +struct StructuredSparseCompressorSelector< + ProblemShape, + ElementA, + LayoutATag, + SparseConfig, + arch::Sm100> { + using Compressor = SM90StructuredSparseCompressor< + ProblemShape, + ElementA, + LayoutATag, + SparseConfig + >; +}; + +template< + class ProblemShape, + class ElementA, + class LayoutATag, + class SparseConfig +> +struct StructuredSparseCompressorSelector< + ProblemShape, + ElementA, + LayoutATag, + 
SparseConfig, + arch::Sm120> { + using Compressor = SM90StructuredSparseCompressor< + ProblemShape, + ElementA, + LayoutATag, + SparseConfig + >; +}; + +template< + class ProblemShape, + class ElementA, + class LayoutATag, + class SparseConfig, + class ArchTag +> +using StructuredSparseCompressor = typename StructuredSparseCompressorSelector< + ProblemShape, + ElementA, + LayoutATag, + SparseConfig, + ArchTag +>::Compressor; + +} // End namespace cutlass::transform::kernel diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/pitch_linear_thread_map.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/pitch_linear_thread_map.h new file mode 100644 index 0000000000000000000000000000000000000000..ef553aab2043775758c2a87d422456dc5cca2426 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/pitch_linear_thread_map.h @@ -0,0 +1,926 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing how threads are mapped to a given tile. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { + +//////////////////////////////////////////////////////////////////////////////// + +/// Strip-mines a pitch-linear tile among a given number of threads, first along +/// the contiguous dimension then along the strided dimension. +/// +/// The tile must be divisible by the thread count such that all threads may +/// execute the same number of iterations with the same delta to exhaustively +/// cover the tile. +/// +/// This class satisfies the "RegularThreadMapping" concept. +/// +/// This ThreadMap is used by SIMT kernels and operand E of the sparse tensor +/// kernels. 
+template < + typename Shape_, + int Threads, + int ElementsPerAccess = 1 +> +struct PitchLinearStripminedThreadMap { + + /// Tensor coordinate + using TensorCoord = layout::PitchLinearCoord; + + /// Tile shape + using Shape = Shape_; + + /// Number of threads total + static int const kThreads = Threads; + + /// Extract vector length from Layout + static int const kElementsPerAccess = ElementsPerAccess; + + /// Shape of access by each thread + using ThreadAccessShape = layout::PitchLinearShape; + + /// Internal implementation details + struct Detail { + + static_assert(!(Shape::kContiguous % kElementsPerAccess), ""); + + /// Shape of the tile in units of vectors + using ShapeVec = layout::PitchLinearShape< + Shape::kContiguous / kElementsPerAccess, + Shape::kStrided + >; + + static_assert((Threads < ShapeVec::kContiguous && !(ShapeVec::kContiguous % kThreads)) || + (!(kThreads % ShapeVec::kContiguous)), + "Shape must be divisible by number of iterations of each thread."); + }; + + /// Number of iterations by each thread + using Iterations = typename platform::conditional< + Threads >= Detail::ShapeVec::kContiguous, + layout::PitchLinearShape< + 1, + // Redo the comparison here to work around divide by zero compiler + // error. The compiler evaluates both path of platform::conditional. + (Threads >= Detail::ShapeVec::kContiguous + ? 
(Detail::ShapeVec::kStrided + (kThreads / Detail::ShapeVec::kContiguous - 1)) / + (kThreads / Detail::ShapeVec::kContiguous) + : 0)>, + layout::PitchLinearShape>::type; + + + /// Interval between accesses along each dimension of the tensor's logical coordinate space + /// (in units of Elements) + using Delta = typename platform::conditional< + Threads >= Detail::ShapeVec::kContiguous, + layout::PitchLinearShape< + 1, + kThreads / Detail::ShapeVec::kContiguous + >, + layout::PitchLinearShape< + kThreads * kElementsPerAccess, + 1 + > + >::type; + + /// Shape of the tile in units of vectors + using StorageShape = typename platform::conditional< + Threads >= Detail::ShapeVec::kContiguous, + layout::PitchLinearShape, + layout::PitchLinearShape>::type; + + /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space + /// (in units of Elements) + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) { + return TensorCoord( + (thread_id % Detail::ShapeVec::kContiguous) * kElementsPerAccess, + thread_id / Detail::ShapeVec::kContiguous); + } +}; + +/// This ThreadMap is used by GEMV +template < + typename Shape, + int Threads, + int ElementsPerAccess = 1 +> +struct PitchLinearTilePolicyStripminedThreadContiguous +{ + static_assert((Shape::kContiguous % (Threads * ElementsPerAccess)) == 0, + "Contiguous shape must divide number of threads"); + + using TensorCoord = layout::PitchLinearCoord; + + static int const kThreads = Threads; + static int const kElementsPerAccess = ElementsPerAccess; + + using Iterations = layout::PitchLinearShape< + Shape::kContiguous / (kThreads * kElementsPerAccess), + Shape::kStrided>; + + using Delta = layout::PitchLinearShape<1, 1>; + + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) + { + return TensorCoord(thread_id * Iterations::kContiguous * kElementsPerAccess, 0); + } +}; + +template < + typename Shape, + int Threads, + int ElementsPerAccess = 1 +> +struct 
PitchLinearTilePolicyStripminedThreadStrided +{ + static_assert((Shape::kStrided % Threads == 0), + "Strided shape must divide number of threads"); + + using TensorCoord = layout::PitchLinearCoord; + + static int const kThreads = Threads; + static int const kElementsPerAccess = ElementsPerAccess; + + using Iterations = layout::PitchLinearShape< + Shape::kContiguous / kElementsPerAccess, + Shape::kStrided / kThreads>; + + using Delta = layout::PitchLinearShape<1, 1>; + + using ShapeVec = Shape; + + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) + { + + return TensorCoord(0, thread_id * Iterations::kStrided); + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Policy defining a warp-raked arrangement in which a shape is partitioned into contiguous +/// elements. +/// +/// This ThreadMap is used by tensor core kernels. +template < + typename Shape_, + int Threads, + typename WarpThreadArrangement_, + int ElementsPerAccess = 1 +> +struct PitchLinearWarpRakedThreadMap { + + /// Tensor coordinate + using TensorCoord = layout::PitchLinearCoord; + + /// Tile shape + using Shape = Shape_; + + /// Number of threads total + static int const kThreads = Threads; + + /// Extract vector length from Layout + static int const kElementsPerAccess = ElementsPerAccess; + + /// Shape of access by each thread + using ThreadAccessShape = layout::PitchLinearShape; + + /// Internal details made public to facilitate introspection + struct Detail { + + /// Fixed arrangement of threads within a warp (units of threads). 
+ using WarpThreadArrangement = WarpThreadArrangement_; + + /// Number of threads per warp + static int const kWarpSize = WarpThreadArrangement::kCount; + + /// Number of participating warps + static int const kWarpCount = kThreads / kWarpSize; + + static_assert( + !(Shape::kContiguous % kElementsPerAccess), + "Shape must be divisible by vector length."); + + /// Compute the 'shape' of the overall tile in units of vectors + using ShapeInAccesses = layout::PitchLinearShape< + Shape::kContiguous / kElementsPerAccess, + Shape::kStrided + >; + + static_assert( + !(ShapeInAccesses::kContiguous % WarpThreadArrangement::kContiguous), + "ShapeInAccesses must be divisible by WarpThreadArrangement."); + + static_assert( + !(ShapeInAccesses::kStrided % WarpThreadArrangement::kStrided), + "ShapeInAccesses must be divisible by WarpThreadArrangement."); + + // compute number of warp-level accesses total + using WarpAccessIterations = layout::PitchLinearShape< + ShapeInAccesses::kContiguous / WarpThreadArrangement::kContiguous, + ShapeInAccesses::kStrided / WarpThreadArrangement::kStrided + >; + + // Divide it into the number of warps, first partitioning the strided dimension then the + // contiguous. + static int const kWarpsStrided = + (WarpAccessIterations::kStrided >= kWarpCount + ? kWarpCount + : WarpAccessIterations::kStrided); + + static int const kWarpsContiguous = + (kWarpCount > WarpAccessIterations::kStrided + ? 
kWarpCount / kWarpsStrided + : 1); + + /// Arrangement of warps within a threadblock-scoped tile + using WarpArrangement = layout::PitchLinearShape< + kWarpsContiguous, kWarpsStrided + >; + }; + + ///< Iterations along each dimension (concept: PitchLinearShape) + using Iterations = layout::PitchLinearShape< + Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous, + Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided + >; + + static_assert(Iterations::kCount, + "Number of iterations must be non-zero"); + + ///< Delta between accesses (units of elements, concept: PitchLinearShape) + using Delta = layout::PitchLinearShape< + Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess, + Detail::WarpThreadArrangement::kStrided + >; + + /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) { + + int warp_id = (thread_id / Detail::kWarpSize); + int lane_id = (thread_id % Detail::kWarpSize); + + // + // compute warp-level offset + // + + // This is the shape of the entire area covered by a warp's memory access (in units of vectors) + layout::PitchLinearCoord warp_footprint{ + Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, + Detail::WarpThreadArrangement::kStrided * Iterations::kStrided + }; + + // This is the offset of a specific warp (in units of vectors) + layout::PitchLinearCoord warp_offset{ + (warp_id % Detail::kWarpsContiguous), + (warp_id / Detail::kWarpsContiguous) + }; + + // This is the offset of a specific thread within a warp (units of vectors) + layout::PitchLinearCoord thread_offset_in_warp{ + lane_id % Detail::WarpThreadArrangement::kContiguous, + lane_id / Detail::WarpThreadArrangement::kContiguous + }; + + // This is the offset of a thread within a threadblock tile (units of vectors) + layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = + warp_footprint * warp_offset + 
thread_offset_in_warp; + + // This is the offset of a thread within a threadblock tile (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ + thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, + thread_offset_in_threadblock_tile_vec.strided() + }; + + return thread_offset_in_threadblock_tile_base; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Policy defining a warp-raked arrangement in which a shape is partitioned into contiguous +/// elements. Warps are arranged based on a stride. +/// +/// This ThreadMap is used by tensor core kernels for NCxHWx layout. +template < + typename Shape_, + int Threads, + typename WarpThreadArrangement_, + int ElementsPerAccess = 1 +> +struct PitchLinearStridedWarpRakedThreadMap { + + /// Tensor coordinate + using TensorCoord = layout::PitchLinearCoord; + + /// Tile shape + using Shape = Shape_; + + /// Number of threads total + static int const kThreads = Threads; + + using WarpThreadArrangement = WarpThreadArrangement_; + + /// Extract vector length from Layout + static int const kElementsPerAccess = ElementsPerAccess; + + /// Base ThreadMap + using BaseThreadMap = PitchLinearWarpRakedThreadMap< + Shape, + kThreads, + WarpThreadArrangement, + kElementsPerAccess + >; + + /// Shape of access by each thread + using ThreadAccessShape = typename BaseThreadMap::ThreadAccessShape; + + + struct Detail { + + using WarpThreadArrangement = WarpThreadArrangement_; + + using WarpAccessIterations = typename BaseThreadMap::Detail::WarpAccessIterations; + + static int const kWarpSize = BaseThreadMap::Detail::kWarpSize; + + static int const kWarpCount = BaseThreadMap::Detail::kWarpCount; + + using ShapeInAccesses = typename BaseThreadMap::Detail::ShapeInAccesses; + + // Divide it into the number of warps, first partitioning the contiguous dimension then the + // stride. 
+ static int const kWarpsContiguous = + (WarpAccessIterations::kContiguous >= kWarpCount + ? kWarpCount + : WarpAccessIterations::kContiguous); + + static int const kWarpsStrided = + (kWarpCount > WarpAccessIterations::kContiguous + ? kWarpCount / kWarpsContiguous + : 1); + + /// Arrangement of warps within a threadblock-scoped tile + using WarpArrangement = layout::PitchLinearShape< + kWarpsContiguous, kWarpsStrided + >; + + }; + + ///< Iterations along each dimension (concept: PitchLinearShape) + using Iterations = layout::PitchLinearShape< + Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous, + Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided + >; + + static_assert(Iterations::kCount, + "Number of iterations must be non-zero"); + + ///< Delta between accesses (units of elements, concept: PitchLinearShape) + using Delta = typename BaseThreadMap::Delta; + + /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) { + + int warp_id = (thread_id / Detail::kWarpSize); + int lane_id = (thread_id % Detail::kWarpSize); + + // + // compute warp-level offset + // + + // This is the shape of the entire area covered by a warp's memory access (in units of vectors) + layout::PitchLinearCoord warp_footprint{ + Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, + Detail::WarpThreadArrangement::kStrided * Iterations::kStrided + }; + + // This is the offset of a specific warp (in units of vectors) + layout::PitchLinearCoord warp_offset{ + (warp_id % Detail::kWarpsContiguous), + (warp_id / Detail::kWarpsContiguous) + }; + + // This is the offset of a specific thread within a warp (units of vectors) + layout::PitchLinearCoord thread_offset_in_warp{ + lane_id % Detail::WarpThreadArrangement::kContiguous, + lane_id / Detail::WarpThreadArrangement::kContiguous + }; + + // This is the offset of a thread within a threadblock tile 
(units of vectors) + layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = + warp_footprint * warp_offset + thread_offset_in_warp; + + // This is the offset of a thread within a threadblock tile (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ + thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, + thread_offset_in_threadblock_tile_vec.strided() + }; + + return thread_offset_in_threadblock_tile_base; + } + + +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Transpose the existing ThreadMap. For example, interleaved layout is like +/// congruous in the global memory and crosswise in the shared memory. We need +/// to transpose the coordinates between two. + +template +struct TransposePitchLinearThreadMap { + /// Underlying ThreadMap + using ThreadMap = ThreadMap_; + + /// Tensor coordinate + using TensorCoord = typename ThreadMap::TensorCoord; + + /// Tile shape + using Shape = typename ThreadMap::Shape; + + /// Number of threads total + static int const kThreads = ThreadMap::kThreads; + + /// Extract vector length from Layout + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + /// Shape of access by each thread + using ThreadAccessShape = layout::PitchLinearShape; + + /// Internal details made public to facilitate introspection + struct Detail { + /// Fixed arrangement of threads within a warp (units of threads). 
+ using WarpThreadArrangement = WarpThreadArrangement_; + + /// Number of threads per warp + static int const kWarpSize = WarpThreadArrangement::kCount; + + /// Number of participating warps + static int const kWarpCount = kThreads / kWarpSize; + + static_assert(!(Shape::kContiguous % kElementsPerAccess), + "Shape must be divisible by vector length."); + + /// Arrangement of warps within a threadblock-scoped tile + using WarpArrangement = + layout::PitchLinearShape; + }; + + ///< Iterations along each dimension (concept: PitchLinearShape) + using Iterations = + layout::PitchLinearShape; + + static_assert(Iterations::kContiguous == 1, + "Contiguous iteration has to be one to reuse the same shared store function with those that don't need transpose"); + + static_assert(Iterations::kCount, "Number of iterations must be non-zero"); + + ///< Delta between accesses (units of elements, concept: PitchLinearShape) + using Delta = + layout::PitchLinearShape; + + /// Maps thread ID to a coordinate offset within the tensor's logical + /// coordinate space Note this is slightly different from the one of + /// PitchLinearWarpRakedThreadMap. + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) { + + int warp_id = (thread_id / Detail::kWarpSize); + int lane_id = (thread_id % Detail::kWarpSize); + + // + // compute warp-level offset + // + + // This is the shape of the entire area covered by a warp's memory access + // (in units of vectors) + layout::PitchLinearCoord warp_footprint{ + Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, + Detail::WarpThreadArrangement::kStrided * Iterations::kStrided}; + + // This is the offset of a specific warp (in units of vectors) + // Note the order of / and %. Also the 2nd operand is kStrided. 
+ layout::PitchLinearCoord warp_offset{ + (warp_id / Detail::WarpArrangement::kStrided), + (warp_id % Detail::WarpArrangement::kStrided)}; + + // This is the offset of a specific thread within a warp (units of vectors) + layout::PitchLinearCoord thread_offset_in_warp{ + lane_id % Detail::WarpThreadArrangement::kContiguous, + lane_id / Detail::WarpThreadArrangement::kContiguous}; + + // This is the offset of a thread within a threadblock tile (units of + // vectors) + layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = + warp_footprint * warp_offset + thread_offset_in_warp; + + // This is the offset of a thread within a threadblock tile (units of + // elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ + thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, + thread_offset_in_threadblock_tile_vec.strided()}; + + return thread_offset_in_threadblock_tile_base; + } +}; + +template +struct TransposePitchLinearThreadMapSimt { + /// Underlying ThreadMap + using ThreadMap = ThreadMap_; + + /// Tensor coordinate + using TensorCoord = typename ThreadMap::TensorCoord; + + /// Tile shape + using Shape = typename ThreadMap::Shape; + + /// Number of threads total + static int const kThreads = ThreadMap::kThreads; + + /// Extract vector length from Layout + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static_assert(kElementsPerAccess == 1 , "Simt transpose requires elements per access to be 1"); + ///< Iterations along each dimension (concept: PitchLinearShape) + using Iterations = + layout::PitchLinearShape; + + static_assert(Iterations::kCount, "Number of iterations must be non-zero"); + + static_assert(Iterations::kStrided == 1, + "Strided iteration has to be one to reuse the same shared store function with those that don't need transpose"); + + /// Shape of access by each thread + using ThreadAccessShape = typename ThreadMap::ThreadAccessShape; + + ///< Delta between accesses (units of 
elements, concept: PitchLinearShape) + using Delta = + layout::PitchLinearShape; + + + /// Maps thread ID to a coordinate offset within the tensor's logical + /// coordinate space Note this is slightly different from the one of + /// PitchLinearWarpRakedThreadMap. + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) { + + TensorCoord coord = ThreadMap::initial_offset(thread_id); + + return TensorCoord( + coord.strided(), + coord.contiguous() + ); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + + +/// Policy defining a warp-striped arrangement. This partitions a tile into vectorized memory +/// accesses performed by each warp then distributes warps across them. Warps are striped in the +/// strided dimension and raked across the contiguous dimension. +template < + typename Shape_, /// Overall shape to partition in units of elements + int Threads, /// Number of partiticipation threads + typename WarpThreadArrangement_, /// Describes the shape of one memory access per warp + int ElementsPerAccess = 1 /// Number of elements accessed by each thread per memory operation (i.e. vector size) +> +struct PitchLinearWarpStripedThreadMap { + + /// Tensor coordinate + using TensorCoord = layout::PitchLinearCoord; + + /// Tile shape + using Shape = Shape_; + + /// Number of threads total + static int const kThreads = Threads; + + /// Extract vector length from Layout + static int const kElementsPerAccess = ElementsPerAccess; + + /// Shape of access by each thread + using ThreadAccessShape = layout::PitchLinearShape; + + /// Internal details made public to facilitate introspection + struct Detail { + + /// Fixed arrangement of threads within a warp (units of threads). 
+ using WarpThreadArrangement = WarpThreadArrangement_; + + /// Number of threads per warp + static int const kWarpSize = WarpThreadArrangement::kCount; + + /// Number of participating warps + static int const kWarpCount = kThreads / kWarpSize; + + static_assert( + !(Shape::kContiguous % kElementsPerAccess), + "Shape must be divisible by vector length."); + + /// Compute the 'shape' of the overall tile in units of vectors + using ShapeInAccesses = layout::PitchLinearShape< + Shape::kContiguous / kElementsPerAccess, + Shape::kStrided + >; + + // compute number of warp-level accesses total + using WarpAccessIterations = layout::PitchLinearShape< + ShapeInAccesses::kContiguous / WarpThreadArrangement::kContiguous, + ShapeInAccesses::kStrided / WarpThreadArrangement::kStrided + >; + + // Divide it into the number of warps, first partitioning the strided dimension then the + // contiguous. + static int const kWarpsStrided = + (WarpAccessIterations::kStrided >= kWarpCount + ? kWarpCount : (kWarpCount / WarpAccessIterations::kStrided)); + + static int const kWarpsContiguous = + (kWarpCount > WarpAccessIterations::kStrided ? 
+ WarpAccessIterations::kContiguous / kWarpsStrided : 1); + + /// Arrangement of warps within a threadblock-scoped tile + using WarpArrangement = layout::PitchLinearShape< + kWarpsContiguous, kWarpsStrided + >; + }; + + ///< Iterations along each dimension (concept: PitchLinearShape) + using Iterations = layout::PitchLinearShape< + Detail::WarpAccessIterations::kContiguous / Detail::kWarpsContiguous, + Detail::WarpAccessIterations::kStrided / Detail::kWarpsStrided + >; + + static_assert(Iterations::kCount, + "Number of iterations must be non-zero"); + + ///< Delta between accesses (units of elements, concept: PitchLinearShape) + using Delta = layout::PitchLinearShape< + Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess, + Detail::WarpThreadArrangement::kStrided * Detail::WarpArrangement::kStrided + >; + + /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) { + + int warp_id = (thread_id / Detail::kWarpSize); + int lane_id = (thread_id % Detail::kWarpSize); + + // + // compute warp-level offset + // + + // This is the shape of the entire area covered by a warp's memory access (in units of vectors) + layout::PitchLinearCoord warp_footprint{ + Detail::WarpThreadArrangement::kContiguous * Iterations::kContiguous, + Detail::WarpThreadArrangement::kStrided + }; + + // This is the offset of a specific warp (in units of vectors) + layout::PitchLinearCoord warp_offset{ + (warp_id % Detail::kWarpsContiguous), + (warp_id / Detail::kWarpsContiguous) + }; + + // This is the offset of a specific thread within a warp (units of vectors) + layout::PitchLinearCoord thread_offset_in_warp{ + lane_id % Detail::WarpThreadArrangement::kContiguous, + lane_id / Detail::WarpThreadArrangement::kContiguous + }; + + // This is the offset of a thread within a threadblock tile (units of vectors) + layout::PitchLinearCoord thread_offset_in_threadblock_tile_vec = + 
warp_footprint * warp_offset + thread_offset_in_warp; + + // This is the offset of a thread within a threadblock tile (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile_base{ + thread_offset_in_threadblock_tile_vec.contiguous() * kElementsPerAccess, + thread_offset_in_threadblock_tile_vec.strided() + }; + + return thread_offset_in_threadblock_tile_base; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Strip-mines a pitch-linear tile among a given number of threads, first along the contiguous +/// dimension then along the strided dimension, while each thread access a 2D thread-tile. +/// +/// The tile must be divisible by the thread count such that all threads may execute the same +/// number of iterations with the same delta to exhaustively cover the tile. +/// +/// This class satisfies the "RegularThreadMapping" concept. +template < + typename Shape_, + int Threads, + typename ThreadTileShape +> +struct PitchLinear2DThreadTileStripminedThreadMap; + + +template < + typename Shape_, + int Threads +> +struct PitchLinear2DThreadTileStripminedThreadMap >{ + + /// Tensor coordinate + using TensorCoord = layout::PitchLinearCoord; + + /// Tile shape + using Shape = Shape_; + + /// Access Shape of each thread + using ThreadAccessShape = cutlass::layout::PitchLinearShape<4, 4>; + //using ThreadAccessShape = ThreadTileShape; + + /// Number of threads total + static int const kThreads = Threads; + + /// Extract length of each access from Layout + static int const kElementsPerAccess = ThreadAccessShape::kContiguous; + + static_assert(!(kElementsPerAccess % 4) , "kElementsPerAccess, needs to be multiple of 4 (32bits)"); + + /// Internal implementation details + struct Detail { + + static_assert(!(ThreadAccessShape::kContiguous % 4), "ThreadAccessShape, needs to be multiple of 4"); + + static_assert(!(Shape::kContiguous % ThreadAccessShape::kContiguous), ""); + + 
static_assert(!((Shape::kContiguous * Shape::kStrided) % (kThreads * ThreadAccessShape::kCount)), + "Shape must be divisible thread count * accesses per thread."); + + /// Shape of the tile in units of vectors + using ShapeVec = layout::PitchLinearShape< + Shape::kContiguous / ThreadAccessShape::kContiguous, + Shape::kStrided / ThreadAccessShape::kStrided + >; + + static_assert( + (Threads < ShapeVec::kContiguous && !(ShapeVec::kContiguous % kThreads)) || + (!(kThreads % ShapeVec::kContiguous) && !(ShapeVec::kStrided % (kThreads / ShapeVec::kContiguous))), + "Shape must be divisible by number of iterations of each thread." + ); + }; + + /// Number of iterations by each thread + using Iterations = typename platform::conditional< + Threads >= Detail::ShapeVec::kContiguous, + layout::PitchLinearShape< + 1, + // Redo the comparison here to work around divide by zero compiler + // error. The compiler evaluates both path of platform::conditional. + (Threads >= Detail::ShapeVec::kContiguous + ? 
Detail::ShapeVec::kStrided / + (kThreads / Detail::ShapeVec::kContiguous) + : 0)>, + layout::PitchLinearShape>::type; + + /// Interval between accesses along each dimension of the tensor's logical coordinate space + /// (in units of Elements) + using Delta = typename platform::conditional< + Threads >= Detail::ShapeVec::kContiguous, + layout::PitchLinearShape< + Shape::kContiguous, + kThreads * ThreadAccessShape::kStrided / Detail::ShapeVec::kContiguous + >, + layout::PitchLinearShape< + kThreads * ThreadAccessShape::kContiguous, + 1 + > + >::type; + + /// Maps thread ID to a coordinate offset within the tensor's logical coordinate space + /// (in units of Elements) + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) { + + return TensorCoord( + (thread_id % Detail::ShapeVec::kContiguous) * ThreadAccessShape::kContiguous, + (thread_id / Detail::ShapeVec::kContiguous) * ThreadAccessShape::kStrided); + } +}; + +/// Thread Mapping a 2D threadtiled mapping as a transposed Pitchlinear2DThreadTile mapping +template +struct TransposePitchLinearThreadMap2DThreadTile { + /// Underlying ThreadMap + using ThreadMap = ThreadMap_; + + /// Tensor coordinate + using TensorCoord = typename ThreadMap::TensorCoord; + + /// Tile shape + using Shape = typename ThreadMap::Shape; + + /// Number of threads total + static int const kThreads = ThreadMap::kThreads; + + /// Extract vector length from Layout + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + + static_assert(kElementsPerAccess > 1 , "Simt transpose requires elements per access to be 1"); + ///< Iterations along each dimension (concept: PitchLinearShape) + using Iterations = + layout::PitchLinearShape; + + static_assert(Iterations::kCount, "Number of iterations must be non-zero"); + + /// Shape of access by each thread + using ThreadAccessShape = typename ThreadMap::ThreadAccessShape; + + ///< Delta between accesses (units of elements, concept: PitchLinearShape) + using Delta = + 
layout::PitchLinearShape; + + + /// Maps thread ID to a coordinate offset within the tensor's logical + /// coordinate space Note this is slightly different from the one of + /// PitchLinearWarpRakedThreadMap. + CUTLASS_HOST_DEVICE + static TensorCoord initial_offset(int thread_id) { + + TensorCoord coord = ThreadMap::initial_offset(thread_id); + return TensorCoord( + coord.strided(), + coord.contiguous() + ); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/thread/transpose.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/thread/transpose.h new file mode 100644 index 0000000000000000000000000000000000000000..508cad846e6d6b819c26570e5dcae9844f712089 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/thread/transpose.h @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Basic copy routines for tensor views +*/ + +#pragma once + +namespace cutlass { +namespace transform { +namespace thread { + +/// Transforms a fragment by doing a transpose +template < + int ElementCount, + typename TransposeShape, + typename Element +> struct Transpose; + +/// Specialization for int8_t 4x4 transpose +template +struct Transpose , int8_t> { + + static const int kElementCount = ElementCount_; + using TransposeShape = layout::PitchLinearShape<4,4>; + using Element = int8_t; + using Fragment = cutlass::Array; + + static_assert(!(kElementCount % TransposeShape::kCount), "Shape needs to be multiple of 16 elements to do a 4x4 transpose"); + + CUTLASS_DEVICE + void transform(Fragment& dst, Fragment& src) { + + // Expose src/dst as int arrays. 
+ int* src_int = reinterpret_cast(&src); + int* dst_int = reinterpret_cast(&dst); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kElementCount / TransposeShape::kCount; i++){ + + int const i0 = 4 * i + 0; + int const i1 = 4 * i + 1; + int const i2 = 4 * i + 2; + int const i3 = 4 * i + 3; + + int a0 = src_int[i0]; + int a1 = src_int[i1]; + int a2 = src_int[i2]; + int a3 = src_int[i3]; + + int b0, b1, b2, b3, c0; + b0 = __byte_perm(a0, a1, 0x0040); + c0 = __byte_perm(a2, a3, 0x0040); + b0 = __byte_perm(b0, c0, 0x5410); + + b1 = __byte_perm(a0, a1, 0x0051); + c0 = __byte_perm(a2, a3, 0x0051); + b1 = __byte_perm(b1, c0, 0x5410); + + b2 = __byte_perm(a0, a1, 0x0062); + c0 = __byte_perm(a2, a3, 0x0062); + b2 = __byte_perm(b2, c0, 0x5410); + + b3 = __byte_perm(a0, a1, 0x0073); + c0 = __byte_perm(a2, a3, 0x0073); + b3 = __byte_perm(b3, c0, 0x5410); + + dst_int[i0] = b0; + dst_int[i1] = b1; + dst_int[i2] = b2; + dst_int[i3] = b3; + } + } +}; + +} // namespace thread +} // namespace layout +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/thread/unary_op.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/thread/unary_op.h new file mode 100644 index 0000000000000000000000000000000000000000..3977af529124dc3db34610046b72145c2a14bf00 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/thread/unary_op.h @@ -0,0 +1,105 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" + +namespace cutlass { +namespace transform { +namespace thread { + +namespace UnaryTransform { + struct Identity; ///< None (i.e., identity) + struct Conjugate; ///< Complex conjugate +} + +/// Element-wise unary operator that transforms one element of a fragment at a time +template< + typename FragmentIn, ///< Input Fragment + typename FragmentOut,///< Output Fragment + typename Transform> ///< Unary transform operator +class UnaryOp +{ + public: + CUTLASS_DEVICE + static FragmentOut execute(FragmentIn &in) + { + static_assert(FragmentIn::kElements == FragmentOut::kElements, "Number of elements must match."); + static_assert(platform::is_same::value || + platform::is_same::value, + "Unary Operator not supported."); + + FragmentOut out; + if (platform::is_same::value ) + { + CUTLASS_PRAGMA_UNROLL + for (int i=0; i < FragmentIn::kElements; ++i){ + out[i] = static_cast(in[i]); + } + } + else if (platform::is_same::value ) + { + for (int i=0; i < FragmentIn::kElements; ++i){ + out[i] = conj(static_cast(in[i])); + } + } + return out; + } +}; + +template +class UnaryOp +{ + public: + CUTLASS_DEVICE + static FragmentIn execute(FragmentIn &in) + { + static_assert(platform::is_same::value || + platform::is_same::value, + "Unary Operator not supported."); + + if (platform::is_same::value ) + { + return in; + } + else if (platform::is_same::value ) + { + for(int i=0; i < FragmentIn::kElements; ++i){ + in[i] = conj(in[i]); + } + } + return in; + } + }; + } + } +} diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_iterator.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_iterator.h new file mode 100644 index 
0000000000000000000000000000000000000000..bd717d678f8234b9fd39f7d22c4de5c231da4c42 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_iterator.h @@ -0,0 +1,199 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Ell iterator for matrix of indices (ellColInd matrix) +*/ + +#pragma once + +namespace cutlass { +namespace transform { +namespace threadblock { + +namespace ell{ + +constexpr unsigned int SmemPow = 8; +constexpr unsigned int SmemStages = 2; +constexpr unsigned int SmemSize = 1 << SmemPow; +constexpr unsigned int SmemMask = (SmemSize*SmemStages-1); + +class SharedStorage{ + public: + Array array; +}; + +class Iterator{ + public: + using Layout = layout::PitchLinear; + using LongIndex = typename Layout::LongIndex; + + private: + const int *gmem_col_idx_; + int *smem_col_idx_; + const int block_size_; + const int base_idx_; + const int k_shape_; + const int ell_increment_; + const int array_length_; + int col_idx_base_; + int residue_; + int counter_; + + int pow2_; + int residue_shape_; + + int smem_offset_; + int smem_stage_; + int gmem_offset_; + + int lane_; + + bool is_pow2_; + bool is_residue_tile_; + + public: + CUTLASS_DEVICE + void load_ell_indices(){ + for(int i=threadIdx.x; i= 0) ? 
gmem_col_idx : -1; + } + gmem_offset_ += SmemSize; + smem_stage_ ^= 1; + } + + CUTLASS_DEVICE + Iterator( + SharedStorage& shared_storage_base, + const int* col_idx, + const int& block_size, + const int& base_idx, + const int k_shape, + const int& problem_size_k, + const int& ell_stride, + const int& thread_idx) + : residue_(0), + counter_(0), + smem_offset_(0), + smem_stage_(0), + gmem_offset_(0), + block_size_(block_size), + base_idx_(base_idx), + k_shape_(k_shape), + ell_increment_(ell_stride * block_size), + array_length_((problem_size_k + block_size_ - 1) / block_size_), + residue_shape_(problem_size_k % k_shape_), + is_residue_tile_(residue_shape_ != 0), + smem_col_idx_(reinterpret_cast(&shared_storage_base.array)), + gmem_col_idx_(const_cast(col_idx)), + lane_(thread_idx % 32) { + + load_ell_indices(); + __syncthreads(); + + is_pow2_ = ((block_size_ & (block_size_ - 1)) == 0); + if( is_pow2_ && k_shape <= block_size_ ) lane_ = 0; + + col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_; + + pow2_ = 0; + while(block_size_ >> (pow2_ + 1)) ++pow2_; + } + + CUTLASS_DEVICE + int get_blocksize(){ + return block_size_; + } + + CUTLASS_DEVICE + Iterator &operator++(){ + if(is_residue_tile_){ + residue_ += residue_shape_; + is_residue_tile_ = false; + } else { + residue_ += k_shape_; + } + + if(residue_ < block_size_){ + return *this; + } + + if((array_length_ > SmemSize) && (((smem_offset_ >> SmemPow) & 1) != smem_stage_)) + load_ell_indices(); + + if(residue_ == block_size_){ + ++smem_offset_; + counter_ += ell_increment_; + residue_ = 0; + col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_ - counter_; + return *this; + } + + if(is_pow2_){ + smem_offset_ += residue_ >> pow2_; + counter_ += (residue_ >> pow2_) * ell_increment_; + residue_ = residue_ & ((1 << pow2_) - 1); + } + else { + smem_offset_ += residue_ / block_size_; + counter_ += (residue_ / block_size_) * ell_increment_; + residue_ %= block_size_; 
+ } + + col_idx_base_ = smem_col_idx_[(smem_offset_ + lane_) & SmemMask] * ell_increment_ - counter_; + + return *this; + } + + CUTLASS_DEVICE + LongIndex get_offset(const int& idx) { + int num_jump_tiles; + if(is_pow2_) + num_jump_tiles = (idx + residue_) >> pow2_; + else + num_jump_tiles = (idx + residue_) / block_size_; + + int tmp = __shfl_sync(0xffffffff, col_idx_base_, num_jump_tiles); + return tmp - num_jump_tiles * ell_increment_; + } + + CUTLASS_DEVICE + LongIndex get_offset_fast() { + return col_idx_base_; + } +}; + +} +} +} +} diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..3676c2339067f9eaad667e11e0d798ae3f4d5c95 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h @@ -0,0 +1,1350 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Ell iterator for Blocked-Ell matrix (ellValue matrix) used with EllMmaMultistage +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// EllPredicatedTileAccessIterator +/// +template +class EllPredicatedTileAccessIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileAccessIterator for pitch-linear data. 
+/// +template +class EllPredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static int const kPredicatesPerByte = 4; + static int const kPredicatesPerWord = 4 * kPredicatesPerByte; + + static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector; + + /// Number of 32b words containing predicates + static int const kPredicateByteCount = + (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte; + static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; + + static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; + + static_assert(kPredicateWordCount <= 4, "Too many predicates."); + + /// Predicate vector stores mask to guard accesses + using Mask = Array; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend EllPredicatedTileAccessIterator; + + private: + /// stride of pitch-linear layout (units of Element) + LongIndex stride_; + /// amount (in byte) to increment pointer to move to next access along + /// strided dimension + 
LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + public: + + // Default ctor + CUTLASS_HOST_DEVICE + Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : stride_(layout.stride(0)) { + inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) * + sizeof_bits::value / 8; + + if (kAdvanceRank) { + // advance along strided dimension + inc_advance_ = + Shape::kStrided * LongIndex(stride_) * sizeof_bits::value / 8; + } else { + // advance along contiguous dimension + inc_advance_ = Shape::kContiguous * sizeof_bits::value / 8; + } + + inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * + ThreadMap::Delta::kStrided * LongIndex(stride_) * + sizeof_bits::value / 8; + }; + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const ¶ms_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Guard predicates + uint32_t predicates_[kPredicateWordCount]; + + /// Size of tensor + TensorCoord extent_; + + /// Initial offset for each thread + TensorCoord thread_offset_; + + /// Offset to the first steady-state tile + TensorCoord residue_offset_; + + /// Initial offset to define ELL block + TensorCoord ell_offset_; + + /// Used for out-of-order visitation + bool is_residue_tile_; + + /// Iteration along vectors implied by the thread map + int iteration_vector_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int 
iteration_strided_; + + public: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = 0u; + } + + CUTLASS_PRAGMA_UNROLL + for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { + + int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int c = access_residual / kAccessesPerVector; + int v = access_residual % kAccessesPerVector; + + TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, + s * ThreadMap::Delta::kStrided); + + TensorCoord coord = thread_offset_ + iteration_coord; + + bool guard; + + if (is_steady_state) { + if (kAdvanceRank == 0) { + guard = (coord.strided() < extent.strided()); + } else { + guard = (coord.contiguous() < extent.contiguous()); + } + } else { + guard = (coord.strided() < extent.strided() && + coord.contiguous() < extent.contiguous()); + } + + int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); + + } + + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + 
TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + extent_(extent), + is_residue_tile_(true) { + + TensorCoord residue_extent; + if (kAdvanceRank) { + + typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided; + if (!residue_size) { + residue_size = Shape::kStrided; + } + + residue_offset_ = make_Coord(0, residue_size); + residue_extent = make_Coord( + extent_.contiguous(), + min(threadblock_offset.strided() + residue_size, extent_.strided()) + ); + } else { + + typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous; + if (!residue_size) { + residue_size = Shape::kContiguous; + } + + residue_offset_ = make_Coord(residue_size, 0); + + residue_extent = make_Coord( + min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size), + extent_.strided() + ); + } + + // Per-thread offset in logical coordinates of tensor + ell_offset_ = ThreadMap::initial_offset(thread_id); + thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(thread_offset_)); + + compute_predicates_(residue_extent, false); + + set_iteration_index(0); + } + + /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void 
set_iteration_index(int index) { + + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + if (is_residue_tile_) { + + thread_offset_ += residue_offset_; + + Layout layout(params_.stride_); + add_pointer_offset(layout(residue_offset_)); + + compute_predicates_(extent_, true); + + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } else { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } + is_residue_tile_ = false; + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast( + pointer_ + + iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value) / 8) + iteration_vector_; + } + + /// Returns a k_location + CUTLASS_HOST_DEVICE + int get_k() const { + if(kAdvanceRank){ //strided + return ell_offset_.strided() + iteration_strided_ * ThreadMap::Delta::kStrided; + }else{ + return ell_offset_.contiguous() + iteration_contiguous_ * 
ThreadMap::Delta::kContiguous + iteration_vector_ * AccessType::kElements; + } + } + + CUTLASS_HOST_DEVICE + int get_stride() const { + if(kAdvanceRank) + return params_.stride_; + else + return 1; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator &operator++() { + + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + + iteration_vector_ = 0; + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + EllPredicatedTileAccessIterator operator++(int) { + EllPredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = enable ? 
0u : predicates_[i]; + } + + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = 0xffffffff; + } + + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = mask[i]; + } + + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + mask[i] = predicates_[i]; + } + } + + /// add mask for small tiles in ELL + CUTLASS_DEVICE + void ell_add_mask(int blocksize) { + + Mask mask; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + mask[i] = 0u; + } + + CUTLASS_PRAGMA_UNROLL + for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { + + int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int c = access_residual / kAccessesPerVector; + int v = access_residual % kAccessesPerVector; + + TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, + s * ThreadMap::Delta::kStrided); + + TensorCoord coord = ell_offset_ + iteration_coord; + + bool guard; + + if (kAdvanceRank == 0) { + guard = (coord.strided() < blocksize); + } else { + guard = (coord.contiguous() < blocksize); + } + + int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + mask[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); + + } + + 
CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < kPredicateWordCount; ++i) {
+      mask[i] &= predicates_[i];
+    }
+    set_mask(mask);
+  }
+
+  /// Returns whether the access selected by the current iteration indices is valid.
+  ///
+  /// A linear predicate index is formed from iteration_vector_, iteration_contiguous_ and
+  /// iteration_strided_, then split into a word index, a byte within that word and a bit
+  /// within that byte; the result is that single bit of predicates_.
+  CUTLASS_HOST_DEVICE
+  bool valid() {
+
+    int pred_idx = 
+      iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous);
+
+    int word_idx = pred_idx / kPredicatesPerWord;
+    int residual = pred_idx % kPredicatesPerWord;
+    int byte_idx = residual / kPredicatesPerByte;
+    int bit_idx = residual % kPredicatesPerByte;
+    
+    bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
+    return pred;
+
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Specialization of EllPredicatedTileAccessIterator for column-major data.
+///
+/// Every member forwards to an underlying EllPredicatedTileAccessIterator over
+/// layout::PitchLinear (see UnderlyingIterator); (row, column) coordinates are passed
+/// through in that order.
+///
+/// Satisfies: ForwardTileIteratorConcept |
+///            ReadableContiguousTileIteratorConcept |
+///            WriteableContiguousTileIteratorConcept |
+///            MaskedTileIteratorConcept
+///
+template 
+class EllPredicatedTileAccessIterator {
+ public:
+  static_assert(
+      AdvanceRank == 0 || AdvanceRank == 1,
+      "Specialization for pitch-linear iterator may along advance along the "
+      "contiguous(rank=0) or strided(rank=1) dimension.");
+
+  using Shape = Shape_;
+  using Element = Element_;
+  using Layout = layout::ColumnMajor;
+  static int const kAdvanceRank = AdvanceRank;
+  using ThreadMap = ThreadMap_;
+  using AccessType = AccessType_;
+
+  using Index = typename Layout::Index;
+  using LongIndex = typename Layout::LongIndex;
+
+  using TensorRef = TensorRef;
+  using TensorView = TensorView;
+  using TensorCoord = typename Layout::TensorCoord;
+
+  using Pointer = Element *;
+  using NonConstPointer = typename platform::remove_const::type *;
+
+  /// Pitch-linear iterator that does the actual work; the advance rank is passed
+  /// through unchanged (0 stays 0, 1 stays 1).
+  using UnderlyingIterator = EllPredicatedTileAccessIterator<
+      layout::PitchLinearShape, Element,
+      layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>;
+
+  /// Predicate vector stores mask to guard accesses
+  using Mask = typename UnderlyingIterator::Mask;
+
+  static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
+
+  /// Parameters object is precomputed state and is host-constructible
+  class Params {
+   private:
+    friend EllPredicatedTileAccessIterator;
+
+    /// Parameters object of the underlying pitch-linear iterator
+    typename UnderlyingIterator::Params params_;
+
+   public:
+
+    /// Default ctor
+    CUTLASS_HOST_DEVICE
+    Params() { }
+
+    /// Construct the Params object from a column-major layout; only layout.stride(0) is
+    /// used, wrapped in a layout::PitchLinear.
+    CUTLASS_HOST_DEVICE
+    Params(Layout const &layout)
+        : params_(layout::PitchLinear(layout.stride(0))){};
+  };
+
+ private:
+  //
+  // Data members
+  //
+
+  /// Underlying pitch-linear tile iterator
+  UnderlyingIterator iterator_;
+
+ public:
+  /// Constructs a TileIterator from its precomputed state, threadblock offset,
+  /// and thread ID. Extent and offset are handed to the underlying iterator as
+  /// PitchLinearCoord(row, column).
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileAccessIterator(
+      ///< Precomputed parameters object
+      Params const ¶ms,
+      ///< Pointer to start of tensor
+      Pointer pointer,
+      ///< Extent of tensor
+      TensorCoord extent,
+      ///< ID of each participating thread
+      int thread_id,
+      ///< Initial offset of threadblock
+      TensorCoord const &threadblock_offset)
+      : iterator_(params.params_, pointer,
+                  layout::PitchLinearCoord(extent.row(), extent.column()),
+                  thread_id,
+                  layout::PitchLinearCoord(threadblock_offset.row(),
+                                           threadblock_offset.column())) {}
+
+  /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset
+  /// (delegates to the five-argument constructor with make_Coord(0, 0)).
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileAccessIterator(
+      Params const ¶ms,  ///< Precomputed parameters object
+      Pointer pointer,       ///< Pointer to start of tensor
+      TensorCoord extent,    ///< Extent of tensor
+      int thread_id          ///< ID of each participating thread
+      )
+      : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id,
+                                     make_Coord(0, 0)) {}
+
+  /// Overrides the internal iteration index
+  CUTLASS_HOST_DEVICE
+  void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
+
+  /// Adds a pointer offset in units of Element
+  CUTLASS_HOST_DEVICE
+  void add_pointer_offset(LongIndex pointer_offset) {
+    iterator_.add_pointer_offset(pointer_offset);
+  }
+
+  /// Advances an iterator along logical dimensions of matrix in units of whole
+  /// tiles; the offset is forwarded as {row, column}.
+  CUTLASS_HOST_DEVICE
+  void add_tile_offset(TensorCoord const &tile_offset) {
+    iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
+  }
+
+  /// Returns the underlying iterator's current pointer, cast to an access pointer
+  CUTLASS_HOST_DEVICE
+  AccessType *get() const {
+    return reinterpret_cast(iterator_.get());
+  }
+
+  /// Returns get_k() of the underlying iterator
+  CUTLASS_HOST_DEVICE
+  int get_k() const {
+    return iterator_.get_k();
+  }
+
+  /// Returns get_stride() of the underlying iterator
+  CUTLASS_HOST_DEVICE
+  int get_stride() const {
+    return iterator_.get_stride();
+  }
+
+  /// Advances to the next tile in memory (pre-increment of the underlying iterator).
+  ///
+  /// The first time this method is called, predicates are updated, and the
+  /// iterator's internal pointer is reverted to the first "steady state" tile.
+  /// Subsequent calls are lightweight and must only update the internal
+  /// pointer.
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileAccessIterator &operator++() {
+    ++iterator_;
+    return *this;
+  }
+
+  /// Advances to the next tile in memory and returns a copy of the iterator taken
+  /// before the advance (post-increment).
+  ///
+  /// The first time this method is called, predicates are updated, and the
+  /// iterator's internal pointer is reverted to the first "steady state" tile.
+  /// Subsequent calls are lightweight and must only update the internal
+  /// pointer.
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileAccessIterator operator++(int) {
+    EllPredicatedTileAccessIterator self(*this);
+    operator++();
+    return self;
+  }
+
+  /// Clears the predicate set efficiently
+  CUTLASS_HOST_DEVICE
+  void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
+
+  /// Forwards to enable_mask() of the underlying iterator
+  CUTLASS_HOST_DEVICE
+  void enable_mask() { iterator_.enable_mask(); }
+
+  /// Sets the predicate mask, overriding value stored in predicate iterator
+  CUTLASS_HOST_DEVICE
+  void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
+
+  /// Gets the mask
+  CUTLASS_HOST_DEVICE
+  void get_mask(Mask &mask) { iterator_.get_mask(mask); }
+
+  /// Adds the mask for small tiles in ELL (device-only; forwards blocksize unchanged)
+  CUTLASS_DEVICE
+  void ell_add_mask(int blocksize) {
+    iterator_.ell_add_mask(blocksize);
+  }
+
+  /// Returns whether access is valid or not
+  CUTLASS_HOST_DEVICE
+  bool valid() {
+    return iterator_.valid();
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Specialization of EllPredicatedTileAccessIterator for row-major data.
+///
+/// This specialization sets Layout = layout::RowMajor and forwards every member to an
+/// underlying EllPredicatedTileAccessIterator over layout::PitchLinear; (row, column)
+/// coordinates are swapped to (column, row) on the way in.
+///
+/// Satisfies: ForwardTileIteratorConcept |
+///            ReadableContiguousTileIteratorConcept |
+///            WriteableContiguousTileIteratorConcept |
+///            MaskedTileIteratorConcept
+///
+template 
+class EllPredicatedTileAccessIterator {
+ public:
+  static_assert(
+      AdvanceRank == 0 || AdvanceRank == 1,
+      "Specialization for pitch-linear iterator may along advance along the "
+      "contiguous(rank=0) or strided(rank=1) dimension.");
+
+  using Shape = Shape_;
+  using Element = Element_;
+  using Layout = layout::RowMajor;
+  static int const kAdvanceRank = AdvanceRank;
+  using ThreadMap = ThreadMap_;
+  using AccessType = AccessType_;
+
+  using Index = typename Layout::Index;
+  using LongIndex = typename Layout::LongIndex;
+
+  using TensorRef = TensorRef;
+  using TensorView = TensorView;
+  using TensorCoord = typename Layout::TensorCoord;
+
+  using Pointer = Element *;
+  using NonConstPointer = typename platform::remove_const::type *;
+
+  /// Pitch-linear iterator that does the actual work; the advance rank is flipped
+  /// (0 becomes 1, 1 becomes 0).
+  using UnderlyingIterator = EllPredicatedTileAccessIterator<
+      layout::PitchLinearShape, Element,
+      layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>;
+
+  static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
+
+  /// Predicate vector stores mask to guard accesses
+  using Mask = typename UnderlyingIterator::Mask;
+
+  /// Parameters object is precomputed state and is host-constructible
+  class Params {
+   private:
+    friend EllPredicatedTileAccessIterator;
+
+    /// Parameters object of the underlying pitch-linear iterator
+    typename UnderlyingIterator::Params params_;
+
+   public:
+
+    /// Default ctor
+    CUTLASS_HOST_DEVICE
+    Params() { }
+
+    /// Construct the Params object from a row-major layout; only layout.stride(0) is
+    /// used, wrapped in a layout::PitchLinear.
+    CUTLASS_HOST_DEVICE
+    Params(Layout const &layout)
+        : params_(layout::PitchLinear(layout.stride(0))){};
+  };
+
+ private:
+  //
+  // Data members
+  //
+
+  /// Underlying pitch-linear tile iterator
+  UnderlyingIterator iterator_;
+
+ public:
+  /// Constructs a TileIterator from its precomputed state, threadblock offset,
+  /// and thread ID. Extent and offset are handed to the underlying iterator as
+  /// PitchLinearCoord(column, row).
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileAccessIterator(
+      ///< Precomputed parameters object
+      Params const ¶ms,
+      ///< Pointer to start of tensor
+      Pointer pointer,
+      ///< Extent of tensor
+      TensorCoord extent,
+      ///< ID of each participating thread
+      int thread_id,
+      ///< Initial offset of threadblock
+      TensorCoord const &threadblock_offset)
+      : iterator_(params.params_, pointer,
+                  layout::PitchLinearCoord(extent.column(), extent.row()),
+                  thread_id,
+                  layout::PitchLinearCoord(threadblock_offset.column(),
+                                           threadblock_offset.row())) {}
+
+  /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset
+  /// (delegates to the five-argument constructor with make_Coord(0, 0)).
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileAccessIterator(
+      Params const ¶ms,  ///< Precomputed parameters object
+      Pointer pointer,       ///< Pointer to start of tensor
+      TensorCoord extent,    ///< Extent of tensor
+      int thread_id          ///< ID of each participating thread
+      )
+      : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id,
+                                     make_Coord(0, 0)) {}
+
+  /// Overrides the internal iteration index
+  CUTLASS_HOST_DEVICE
+  void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
+
+  /// Adds a pointer offset in units of Element
+  CUTLASS_HOST_DEVICE
+  void add_pointer_offset(LongIndex pointer_offset) {
+    iterator_.add_pointer_offset(pointer_offset);
+  }
+
+  /// Advances an iterator along logical dimensions of matrix in units of whole
+  /// tiles; the offset is forwarded as {column, row}.
+  CUTLASS_HOST_DEVICE
+  void add_tile_offset(TensorCoord const &tile_offset) {
+    iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
+  }
+
+  /// Returns the underlying iterator's current pointer, cast to an access pointer
+  CUTLASS_HOST_DEVICE
+  AccessType *get() const {
+    return reinterpret_cast(iterator_.get());
+  }
+
+  /// Returns get_k() of the underlying iterator
+  CUTLASS_HOST_DEVICE
+  int get_k() const {
+    return iterator_.get_k();
+  }
+
+  /// Returns get_stride() of the underlying iterator
+  CUTLASS_HOST_DEVICE
+  int get_stride() const {
+    return iterator_.get_stride();
+  }
+
+  /// Advances to the next tile in memory (pre-increment of the underlying iterator).
+  ///
+  /// The first time this method is called, predicates are updated, and the
+  /// iterator's internal pointer is reverted to the first "steady state" tile.
+  /// Subsequent calls are lightweight and must only update the internal
+  /// pointer.
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileAccessIterator &operator++() {
+    ++iterator_;
+    return *this;
+  }
+
+  /// Advances to the next tile in memory and returns a copy of the iterator taken
+  /// before the advance (post-increment).
+  ///
+  /// The first time this method is called, predicates are updated, and the
+  /// iterator's internal pointer is reverted to the first "steady state" tile.
+  /// Subsequent calls are lightweight and must only update the internal
+  /// pointer.
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileAccessIterator operator++(int) {
+    EllPredicatedTileAccessIterator self(*this);
+    operator++();
+    return self;
+  }
+
+  /// Clears the predicate set efficiently
+  CUTLASS_HOST_DEVICE
+  void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
+
+  /// Forwards to enable_mask() of the underlying iterator
+  CUTLASS_HOST_DEVICE
+  void enable_mask() { iterator_.enable_mask(); }
+
+  /// Sets the predicate mask, overriding value stored in predicate iterator
+  CUTLASS_HOST_DEVICE
+  void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
+
+  /// Gets the mask
+  CUTLASS_HOST_DEVICE
+  void get_mask(Mask &mask) { iterator_.get_mask(mask); }
+
+  /// Adds the mask for small tiles in ELL (device-only; forwards blocksize unchanged)
+  CUTLASS_DEVICE
+  void ell_add_mask(int blocksize) {
+    iterator_.ell_add_mask(blocksize);
+  }
+
+  /// Returns whether access is valid or not
+  CUTLASS_HOST_DEVICE
+  bool valid() {
+    return iterator_.valid();
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Specialization of EllPredicatedTileAccessIterator for column-major interleaved data.
+/// It is mapped to the congruous layout.
+///
+/// Every member forwards to an underlying EllPredicatedTileAccessIterator over
+/// layout::PitchLinear. The constructor rescales coordinates by kInterleavedK
+/// (row * kInterleavedK, column / kInterleavedK); add_tile_offset forwards
+/// {row, column} unscaled.
+///
+/// Satisfies: ForwardTileIteratorConcept |
+///            ReadableContiguousTileIteratorConcept |
+///            WriteableContiguousTileIteratorConcept |
+///            MaskedTileIteratorConcept
+///
+
+template 
+class EllPredicatedTileAccessIterator,
+                                   AdvanceRank, ThreadMap_, AccessType_> {
+ public:
+  static_assert(
+      AdvanceRank == 0 || AdvanceRank == 1,
+      "Specialization for pitch-linear iterator may along advance along the "
+      "contiguous(rank=0) or strided(rank=1) dimension.");
+
+  using Shape = Shape_;
+  using Element = Element_;
+  static int const kInterleavedK = InterleavedK;
+  using Layout = layout::ColumnMajorInterleaved;
+  static int const kAdvanceRank = AdvanceRank;
+  using ThreadMap = ThreadMap_;
+  using AccessType = AccessType_;
+
+  using Index = typename Layout::Index;
+  using LongIndex = typename Layout::LongIndex;
+
+  using TensorRef = TensorRef;
+  using TensorView = TensorView;
+  using TensorCoord = typename Layout::TensorCoord;
+
+  using Pointer = Element *;
+  using NonConstPointer = typename platform::remove_const::type *;
+
+  /// Pitch-linear iterator that does the actual work; the advance rank is passed
+  /// through unchanged (0 stays 0, 1 stays 1).
+  using UnderlyingIterator = EllPredicatedTileAccessIterator<
+      layout::PitchLinearShape,
+      Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap,
+      AccessType>;
+
+  static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
+
+  /// Predicate vector stores mask to guard accesses
+  using Mask = typename UnderlyingIterator::Mask;
+
+  /// Parameters object is precomputed state and is host-constructible
+  class Params {
+   private:
+    friend EllPredicatedTileAccessIterator;
+
+    /// Parameters object of the underlying pitch-linear iterator
+    typename UnderlyingIterator::Params params_;
+
+   public:
+    /// Default ctor
+    CUTLASS_HOST_DEVICE
+    Params() {}
+
+    /// Construct the Params object from a column-major interleaved layout; only
+    /// layout.stride(0) is used, wrapped in a layout::PitchLinear.
+    CUTLASS_HOST_DEVICE
+    Params(Layout const &layout)
+        : params_(layout::PitchLinear(layout.stride(0))) {}
+  };
+
+ private:
+  //
+  // Data members
+  //
+
+  /// Underlying pitch-linear tile iterator
+  UnderlyingIterator iterator_;
+
+ public:
+  /// Constructs a TileIterator from its precomputed state, threadblock offset,
+  /// and thread ID. Extent and offset are handed to the underlying iterator as
+  /// PitchLinearCoord(row * kInterleavedK, column / kInterleavedK).
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileAccessIterator(
+      /// Precomputed parameters object
+      Params const ¶ms,
+      /// Pointer to start of tensor
+      Pointer pointer,
+      /// Extent of tensor
+      TensorCoord extent,
+      /// ID of each participating thread
+      int thread_id,
+      /// Initial offset of threadblock
+      TensorCoord const &threadblock_offset)
+      : iterator_(params.params_, pointer,
+                  layout::PitchLinearCoord(extent.row() * kInterleavedK,
+                                           extent.column() / kInterleavedK),
+                  thread_id,
+                  layout::PitchLinearCoord(
+                      threadblock_offset.row() * kInterleavedK,
+                      threadblock_offset.column() / kInterleavedK)) {}
+
+  /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset
+  /// (delegates to the five-argument constructor with make_Coord(0, 0)).
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileAccessIterator(
+      Params const ¶ms,  ///< Precomputed parameters object
+      Pointer pointer,       ///< Pointer to start of tensor
+      TensorCoord extent,    ///< Extent of tensor
+      int thread_id          ///< ID of each participating thread
+      )
+      : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id,
+                                     make_Coord(0, 0)) {}
+
+  /// Overrides the internal iteration index
+  CUTLASS_HOST_DEVICE
+  void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
+
+  /// Adds a pointer offset in units of Element
+  CUTLASS_HOST_DEVICE
+  void add_pointer_offset(LongIndex pointer_offset) {
+    iterator_.add_pointer_offset(pointer_offset);
+  }
+
+  /// Advances an iterator along logical dimensions of matrix in units of whole
+  /// tiles; the offset is forwarded as {row, column} without kInterleavedK scaling.
+  CUTLASS_HOST_DEVICE
+  void add_tile_offset(TensorCoord const &tile_offset) {
+    iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
+  }
+
+  /// Returns the underlying iterator's current pointer, cast to an access pointer
+  CUTLASS_HOST_DEVICE
+  AccessType *get() const {
+    return reinterpret_cast(iterator_.get());
+  }
+
+  /// Returns get_k() of the underlying iterator
+  CUTLASS_HOST_DEVICE
+  int get_k() const {
+    return iterator_.get_k();
+  }
+
+  /// Returns get_stride() of the underlying iterator
+  CUTLASS_HOST_DEVICE
+  int get_stride() const {
+    return iterator_.get_stride();
+  }
+
+  /// Advances to the next tile in memory (pre-increment of the underlying iterator).
+  ///
+  /// The first time this method is called, predicates are updated, and the
+  /// iterator's internal pointer is reverted to the first "steady state" tile.
+  /// Subsequent calls are lightweight and must only update the internal
+  /// pointer.
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileAccessIterator &operator++() {
+    ++iterator_;
+    return *this;
+  }
+
+  /// Advances to the next tile in memory and returns a copy of the iterator taken
+  /// before the advance (post-increment).
+  ///
+  /// The first time this method is called, predicates are updated, and the
+  /// iterator's internal pointer is reverted to the first "steady state" tile.
+  /// Subsequent calls are lightweight and must only update the internal
+  /// pointer.
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileAccessIterator operator++(int) {
+    EllPredicatedTileAccessIterator self(*this);
+    operator++();
+    return self;
+  }
+
+  /// Clears the predicate set efficiently
+  CUTLASS_HOST_DEVICE
+  void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
+
+  /// Forwards to enable_mask() of the underlying iterator
+  CUTLASS_HOST_DEVICE
+  void enable_mask() { iterator_.enable_mask(); }
+
+  /// Sets the predicate mask, overriding value stored in predicate iterator
+  CUTLASS_HOST_DEVICE
+  void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
+
+  /// Gets the mask
+  CUTLASS_HOST_DEVICE
+  void get_mask(Mask &mask) { iterator_.get_mask(mask); }
+
+  /// Adds the mask for small tiles in ELL (device-only; forwards blocksize unchanged)
+  CUTLASS_DEVICE
+  void ell_add_mask(int blocksize) {
+    iterator_.ell_add_mask(blocksize);
+  }
+
+  /// Returns whether access is valid or not
+  CUTLASS_HOST_DEVICE
+  bool valid() { return iterator_.valid(); }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Specialization of EllPredicatedTileAccessIterator for row-major interleaved data.
+/// It is mapped to the congruous layout.
+///
+/// Every member forwards to an underlying EllPredicatedTileAccessIterator over
+/// layout::PitchLinear. The constructor swaps and rescales coordinates
+/// (column * kInterleavedK, row / kInterleavedK); add_tile_offset forwards
+/// {column, row} unscaled.
+///
+/// Satisfies: ForwardTileIteratorConcept |
+///            ReadableContiguousTileIteratorConcept |
+///            WriteableContiguousTileIteratorConcept |
+///            MaskedTileIteratorConcept
+///
+template 
+class EllPredicatedTileAccessIterator,
+                                   AdvanceRank, ThreadMap_, AccessType_> {
+ public:
+  static_assert(
+      AdvanceRank == 0 || AdvanceRank == 1,
+      "Specialization for pitch-linear iterator may along advance along the "
+      "contiguous(rank=0) or strided(rank=1) dimension.");
+
+  using Shape = Shape_;
+  using Element = Element_;
+  static int const kInterleavedK = InterleavedK;
+  using Layout = layout::RowMajorInterleaved;
+  static int const kAdvanceRank = AdvanceRank;
+  using ThreadMap = ThreadMap_;
+  using AccessType = AccessType_;
+
+  using Index = typename Layout::Index;
+  using LongIndex = typename Layout::LongIndex;
+
+  using TensorRef = TensorRef;
+  using TensorView = TensorView;
+  using TensorCoord = typename Layout::TensorCoord;
+
+  using Pointer = Element *;
+  using NonConstPointer = typename platform::remove_const::type *;
+
+  /// Pitch-linear iterator that does the actual work; the advance rank is flipped
+  /// (0 becomes 1, 1 becomes 0).
+  using UnderlyingIterator = EllPredicatedTileAccessIterator<
+      layout::PitchLinearShape,
+      Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap,
+      AccessType>;
+
+
+  static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
+
+  /// Predicate vector stores mask to guard accesses
+  using Mask = typename UnderlyingIterator::Mask;
+
+  /// Parameters object is precomputed state and is host-constructible
+  class Params {
+   private:
+    friend EllPredicatedTileAccessIterator;
+
+    /// Parameters object of the underlying pitch-linear iterator
+    typename UnderlyingIterator::Params params_;
+
+   public:
+    /// Default ctor
+    CUTLASS_HOST_DEVICE
+    Params() {}
+
+    /// Construct the Params object from a row-major interleaved layout; only
+    /// layout.stride(0) is used, wrapped in a layout::PitchLinear.
+    CUTLASS_HOST_DEVICE
+    Params(Layout const &layout)
+        : params_(layout::PitchLinear(layout.stride(0))) {}
+  };
+
+ private:
+  //
+  // Data members
+  //
+
+  /// Underlying pitch-linear tile iterator
+  UnderlyingIterator iterator_;
+
+ public:
+  /// Constructs a TileIterator from its precomputed state, threadblock offset,
+  /// and thread ID. Extent and offset are handed to the underlying iterator as
+  /// PitchLinearCoord(column * kInterleavedK, row / kInterleavedK).
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileAccessIterator(
+      /// Precomputed parameters object
+      Params const ¶ms,
+      /// Pointer to start of tensor
+      Pointer pointer,
+      /// Extent of tensor
+      TensorCoord extent,
+      /// ID of each participating thread
+      int thread_id,
+      /// Initial offset of threadblock
+      TensorCoord const &threadblock_offset)
+      : iterator_(params.params_, pointer,
+                  layout::PitchLinearCoord(extent.column() * kInterleavedK,
+                                           extent.row() / kInterleavedK),
+                  thread_id,
+                  layout::PitchLinearCoord(
+                      threadblock_offset.column() * kInterleavedK,
+                      threadblock_offset.row() / kInterleavedK)) {}
+
+  /// Construct a EllPredicatedTileAccessIterator with zero threadblock offset
+  /// (delegates to the five-argument constructor with make_Coord(0, 0)).
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileAccessIterator(
+      Params const ¶ms,  ///< Precomputed parameters object
+      Pointer pointer,       ///< Pointer to start of tensor
+      TensorCoord extent,    ///< Extent of tensor
+      int thread_id          ///< ID of each participating thread
+      )
+      : EllPredicatedTileAccessIterator(params, pointer, extent, thread_id,
+                                     make_Coord(0, 0)) {}
+
+  /// Overrides the internal iteration index
+  CUTLASS_HOST_DEVICE
+  void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
+
+  /// Adds a pointer offset in units of Element
+  CUTLASS_HOST_DEVICE
+  void add_pointer_offset(LongIndex pointer_offset) {
+    iterator_.add_pointer_offset(pointer_offset);
+  }
+
+  /// Advances an iterator along logical dimensions of matrix in units of whole
+  /// tiles; the offset is forwarded as {column, row} without kInterleavedK scaling.
+  CUTLASS_HOST_DEVICE
+  void add_tile_offset(TensorCoord const &tile_offset) {
+    iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
+  }
+
+  /// Returns the underlying iterator's current pointer, cast to an access pointer
+  CUTLASS_HOST_DEVICE
+  AccessType *get() const {
+    return reinterpret_cast(iterator_.get());
+  }
+
+  /// Returns get_k() of the underlying iterator
+  CUTLASS_HOST_DEVICE
+  int get_k() const {
+    return iterator_.get_k();
+  }
+
+  /// Returns get_stride() of the underlying iterator
+  CUTLASS_HOST_DEVICE
+  int get_stride() const {
+    return iterator_.get_stride();
+  }
+
+  /// Advances to the next tile in memory (pre-increment of the underlying iterator).
+  ///
+  /// The first time this method is called, predicates are updated, and the
+  /// iterator's internal pointer is reverted to the first "steady state" tile.
+  /// Subsequent calls are lightweight and must only update the internal
+  /// pointer.
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileAccessIterator &operator++() {
+    ++iterator_;
+    return *this;
+  }
+
+  /// Advances to the next tile in memory and returns a copy of the iterator taken
+  /// before the advance (post-increment).
+  ///
+  /// The first time this method is called, predicates are updated, and the
+  /// iterator's internal pointer is reverted to the first "steady state" tile.
+  /// Subsequent calls are lightweight and must only update the internal
+  /// pointer.
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileAccessIterator operator++(int) {
+    EllPredicatedTileAccessIterator self(*this);
+    operator++();
+    return self;
+  }
+
+  /// Clears the predicate set efficiently
+  CUTLASS_HOST_DEVICE
+  void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
+
+  /// Forwards to enable_mask() of the underlying iterator
+  CUTLASS_HOST_DEVICE
+  void enable_mask() { iterator_.enable_mask(); }
+
+  /// Sets the predicate mask, overriding value stored in predicate iterator
+  CUTLASS_HOST_DEVICE
+  void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
+
+  /// Gets the mask
+  CUTLASS_HOST_DEVICE
+  void get_mask(Mask &mask) { iterator_.get_mask(mask); }
+
+  /// Adds the mask for small tiles in ELL (device-only; forwards blocksize unchanged)
+  CUTLASS_DEVICE
+  void ell_add_mask(int blocksize) {
+    iterator_.ell_add_mask(blocksize);
+  }
+
+  /// Returns whether access is valid or not
+  CUTLASS_HOST_DEVICE
+  bool valid() { return iterator_.valid(); }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace threadblock
+} // namespace transform
+} // namespace cutlass
+
+////////////////////////////////////////////////////////////////////////////////
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h
new file mode 100644
index 0000000000000000000000000000000000000000..e377bba4c454267737bffda73b1dff7572174ee7
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/ell_predicated_tile_iterator.h
@@ -0,0 +1,1315 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Ell iterator for Blocked-Ell matrix (ellValue matrix) used with EllMmaPipelined +*/ + +#pragma once + +#include "cutlass/arch/memory.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" + +#include "cutlass/transform/threadblock/ell_predicated_tile_access_iterator.h" +#include "cutlass/transform/threadblock/ell_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// EllPredicatedTileIterator +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize register liveness +/// and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. +/// +/// Base pointer and tensor extents may be specified at the time the iterator is constructed. +/// Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. +/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. +/// +/// Visitation order is intended to first visit a "residual" tile that may be partially full in +/// both the advance dimension and the steady-state dimension. This is assumed to be the last +/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to +/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent +/// accesses may be performed without updating internal predicates and are efficient in terms of +/// live register state and pointer arithmetic instructions. 
+///
+/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once
+/// outside any looping structure to minimize integer arithmetic.
+///
+/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing
+/// the iterator.
+///
+///
+/// Example:
+///
+/// An efficient pipeline structure may be constructed as follows:
+///
+// template 
+// __global__ void kernel(
+//   typename Iterator::Params params,
+//   typename Iterator::Element *ptr,
+//   TensorCoord extent) {
+//
+//   typename Iterator::Fragment fragment;
+//
+//   TensorCoord threadblock_offset(0, 0);
+//
+//   Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offset);
+//
+//
+//   fragment = *iter;        // load "residue" tile first
+//   ++iter;                  // advance to first "steady state" tile and update internal masks
+//
+//
+//   #pragma unroll
+//   for (int i = Remaining - 1; i >= 0; --i) {
+//
+//     f(fragment);
+//
+//     if (!i) {
+//       iter.clear_mask();   // light-weight operation to clear masks - subsequent loads become NO-OPs.
+//     }
+//
+//     fragment = *iter;      // load tile during "steady state" phase
+//     ++iter;                // advance to next tile - lightweight due to steady-state masks
+//   }
+// }
+//
+// void host(TensorView view) {
+//
+//   using Iterator = transform::threadblock::EllPredicatedTileIterator;
+//
+//   typename Iterator::Params params(view.layout());
+//
+//   kernel(params, view.data());
+// }
+///
+/// Primary template (declaration only); AccessSize defaults to ThreadMap::kElementsPerAccess.
+///
+template <
+  typename Shape,
+  typename Element,
+  typename Layout,
+  int AdvanceRank,
+  typename ThreadMap,
+  int AccessSize = ThreadMap::kElementsPerAccess
+>
+class EllPredicatedTileIterator;
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Specialization of EllPredicatedTileIterator for pitch-linear data.
+///
+/// Loads and stores whole Fragments by walking an EllPredicatedTileAccessIterator
+/// (address_iterator_) over every strided / contiguous / vector access of the thread map.
+///
+/// Satisfies: ForwardTileIteratorConcept |
+///            ReadableContiguousTileIteratorConcept |
+///            WriteableContiguousTileIteratorConcept |
+///            MaskedTileIteratorConcept
+///
+template 
+class EllPredicatedTileIterator {
+ public:
+  static_assert(
+      AdvanceRank == 0 || AdvanceRank == 1,
+      "Specialization for pitch-linear iterator may along advance along the "
+      "contiguous(rank=0) or strided(rank=1) dimension.");
+
+  using Shape = Shape_;
+  using Element = Element_;
+  using Layout = layout::PitchLinear;
+  static int const kAdvanceRank = AdvanceRank;
+  using ThreadMap = ThreadMap_;
+
+  using Index = typename Layout::Index;
+  using LongIndex = typename Layout::LongIndex;
+
+  using TensorRef = TensorRef;
+  using TensorView = TensorView;
+  using TensorCoord = typename Layout::TensorCoord;
+
+  using Pointer = Element *;
+  using NonConstPointer = typename platform::remove_const::type *;
+
+  /// Type used for internal memory accesses
+  using AccessType = AlignedArray::value / 8)>;
+
+  /// Underlying iterator to compute the addresses
+  using TileAccessIterator =
+      EllPredicatedTileAccessIterator;
+
+  static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
+
+  /// Fragment object to be loaded or stored
+  using Fragment = cutlass::Array;
+
+  /// Predicate vector stores mask to guard accesses
+  using Mask = typename TileAccessIterator::Mask;
+
+  /// Iterator for ELL storage
+  using EllIterator = typename cutlass::transform::threadblock::ell::Iterator;
+
+  /// Parameters object is precomputed state and is host-constructible
+  class Params {
+   public:
+    friend EllPredicatedTileIterator;
+
+   private:
+    /// Parameters object of the tile access iterator
+    typename TileAccessIterator::Params params_;
+
+   public:
+    /// Construct the Params object given a pitch-linear tensor's layout
+    CUTLASS_HOST_DEVICE
+    Params(Layout const &layout) : params_(layout) { }
+
+    /// Default ctor
+    CUTLASS_HOST_DEVICE
+    Params() { }
+  };
+
+ private:
+  /// Internal pointer type permits fast address arithmetic
+  using BytePointer = char *;
+
+ private:
+  //
+  // Data members
+  //
+
+  /// Data member to the tile access iterator
+  TileAccessIterator address_iterator_;
+
+ public:
+  /// Constructs a TileIterator from its precomputed state, threadblock offset,
+  /// and thread ID (all arguments are forwarded to the tile access iterator).
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileIterator(
+      /// Precomputed parameters object
+      Params const ¶ms,
+      /// Pointer to start of tensor
+      Pointer pointer,
+      /// Extent of tensor
+      TensorCoord extent,
+      /// ID of each participating thread
+      int thread_id,
+      /// Initial offset of threadblock
+      TensorCoord const &threadblock_offset)
+      : address_iterator_(params.params_, pointer, extent, thread_id,
+                          threadblock_offset) {}
+
+  /// Construct a EllPredicatedTileIterator with zero threadblock offset
+  /// (delegates to the five-argument constructor with make_Coord(0, 0)).
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileIterator(
+      Params const ¶ms,  ///< Precomputed parameters object
+      Pointer pointer,       ///< Pointer to start of tensor
+      TensorCoord extent,    ///< Extent of tensor
+      int thread_id          ///< ID of each participating thread
+      )
+      : EllPredicatedTileIterator(params, pointer, extent, thread_id,
+                               make_Coord(0, 0)) {}
+
+  /// Adds a pointer offset in units of Element
+  CUTLASS_HOST_DEVICE
+  void add_pointer_offset(LongIndex pointer_offset) {
+    address_iterator_.add_pointer_offset(pointer_offset);
+  }
+
+  /// Advances to the next tile in memory: a tile offset of {0, 1} when kAdvanceRank is
+  /// nonzero, {1, 0} otherwise.
+  ///
+  /// The first time this method is called, predicates are updated, and the
+  /// iterator's internal pointer is reverted to the first "steady state" tile.
+  /// Subsequent calls are lightweight and must only update the internal
+  /// pointer.
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileIterator &operator++() {
+    if (kAdvanceRank)
+      address_iterator_.add_tile_offset({0, 1});
+    else
+      address_iterator_.add_tile_offset({1, 0});
+
+    return *this;
+  }
+
+  /// Advances to the next tile in memory and returns a copy of the iterator taken
+  /// before the advance (post-increment).
+  ///
+  /// The first time this method is called, predicates are updated, and the
+  /// iterator's internal pointer is reverted to the first "steady state" tile.
+  /// Subsequent calls are lightweight and must only update the internal
+  /// pointer.
+  CUTLASS_HOST_DEVICE
+  EllPredicatedTileIterator operator++(int) {
+    EllPredicatedTileIterator self(*this);
+    operator++();
+    return self;
+  }
+
+  /// Returns a stride
+  CUTLASS_HOST_DEVICE
+  int get_stride() const { return address_iterator_.get_stride(); }
+
+  /// Clears the predicate set efficiently
+  CUTLASS_HOST_DEVICE
+  void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
+
+  /// Forwards to enable_mask() of the tile access iterator
+  CUTLASS_HOST_DEVICE
+  void enable_mask() { address_iterator_.enable_mask(); }
+
+  /// Sets the predicate mask, overriding value stored in predicate iterator
+  CUTLASS_HOST_DEVICE
+  void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); }
+
+  /// Gets the mask
+  CUTLASS_HOST_DEVICE
+  void get_mask(Mask &mask) { address_iterator_.get_mask(mask); }
+
+  /// Adds the mask for small tiles in ELL (forwards blocksize unchanged).
+  /// NOTE(review): the layout wrappers in ell_predicated_tile_access_iterator.h declare
+  /// ell_add_mask as CUTLASS_DEVICE, while this forwarder is CUTLASS_HOST_DEVICE - confirm
+  /// the qualifier on the pitch-linear access iterator's ell_add_mask.
+  CUTLASS_HOST_DEVICE
+  void ell_add_mask(int blocksize) { address_iterator_.ell_add_mask(blocksize); }
+
+  /// Loads a fragment from memory with an extra offset given in units of Element
+  /// (converted to bytes before calling load_with_byte_offset).
+  CUTLASS_DEVICE
+  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
+    load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8);
+  }
+
+  /// Loads a fragment from memory with an extra offset given in bytes.
+  ///
+  /// For each access index the iteration index is set explicitly, byte_offset is added to
+  /// the access pointer, and the load is predicated on address_iterator_.valid().
+  CUTLASS_DEVICE
+  void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
+
+    AccessType *frag_ptr = reinterpret_cast(&frag);
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int v = 0; v < kAccessesPerVector; ++v) {
+
+          int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
+
+          address_iterator_.set_iteration_index(idx);
+          char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset;
+
+          AccessType const *access_ptr = reinterpret_cast(byte_ptr);
+
+          cutlass::arch::global_load(
+              frag_ptr[idx], access_ptr, address_iterator_.valid());
+
+          ++address_iterator_;
+        }
+      }
+    }
+  }
+
+  /// Loads a fragment from memory
+  CUTLASS_DEVICE
+  void load(Fragment &frag) { load_with_byte_offset(frag, 0); }
+
+  /// Loads a fragment, looking up an ELL offset for every access.
+  ///
+  /// For each access, ell_iter.get_offset(address_iterator_.get_k()) is scaled by
+  /// sizeof(Element) and added to the access pointer; the load is predicated on
+  /// valid() && ell_offset >= 0.
+  /// NOTE(review): presumably a negative offset marks an absent ELL block - confirm against
+  /// ell::Iterator. Byte scaling uses sizeof(Element) here but sizeof_bits::value / 8 in
+  /// load_with_pointer_offset - confirm this is intended for sub-byte element types.
+  CUTLASS_DEVICE
+  void load_with_ell_index(Fragment &frag, EllIterator &ell_iter) {
+
+    AccessType *frag_ptr = reinterpret_cast(&frag);
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
+        CUTLASS_PRAGMA_UNROLL
+        for (int v = 0; v < kAccessesPerVector; ++v) {
+
+          int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
+          address_iterator_.set_iteration_index(idx);
+          LongIndex ell_offset = 0;
+
+          int k_offset = address_iterator_.get_k();
+          ell_offset = ell_iter.get_offset(k_offset) * sizeof(Element);
+
+          char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + ell_offset;
+
+          AccessType const *access_ptr = reinterpret_cast(byte_ptr);
+
+          bool is_valid = address_iterator_.valid();
+          is_valid = is_valid && (ell_offset >= 0);
+
+          cutlass::arch::global_load(
+              frag_ptr[idx], access_ptr, is_valid);
+
+          ++address_iterator_;
+        }
+      }
+    }
+  }
+
+  /// Loads a fragment using a single ELL offset for the whole tile.
+  ///
+  /// ell_iter.get_offset_fast() is read once, scaled by sizeof(Element), and applied to
+  /// every access; each load is predicated on valid() && ell_offset >= 0.
+  CUTLASS_DEVICE
+  void load_with_ell_index_fast(Fragment &frag, EllIterator &ell_iter) {
+
+    LongIndex ell_offset = ell_iter.get_offset_fast() * sizeof(Element);
+
+    AccessType *frag_ptr = reinterpret_cast(&frag);
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int v = 0; v < kAccessesPerVector; ++v) {
+
+          int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
+
+          address_iterator_.set_iteration_index(idx);
+          char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + ell_offset;
+
+          AccessType const *access_ptr = reinterpret_cast(byte_ptr);
+
+          bool is_valid = address_iterator_.valid();
+          is_valid = is_valid && (ell_offset >= 0);
+
+          cutlass::arch::global_load(
+              frag_ptr[idx], access_ptr, is_valid);
+
+          ++address_iterator_;
+        }
+      }
+    }
+  }
+  /// Store a fragment to memory with an extra offset given in units of Element
+  /// (converted to bytes before calling store_with_byte_offset).
+  CUTLASS_DEVICE
+  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
+    store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8);
+  }
+
+  /// Store a fragment to memory with an extra offset given in bytes.
+  ///
+  /// The iteration index is reset to 0 once up front and then advanced only by
+  /// ++address_iterator_; each store is skipped unless address_iterator_.valid().
+  CUTLASS_DEVICE
+  void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
+    address_iterator_.set_iteration_index(0);
+    AccessType const *frag_ptr = reinterpret_cast(&frag);
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
+        CUTLASS_PRAGMA_UNROLL
+        for (int v = 0; v < kAccessesPerVector; ++v) {
+
+          int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
+
+          char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset;
+          AccessType *access_ptr = reinterpret_cast(byte_ptr);
+
+          if (address_iterator_.valid()) {
+            *access_ptr = frag_ptr[idx];
+          }
+          ++address_iterator_;
+        }
+      }
+    }
+  }
+
+  /// Store a fragment to memory
+  CUTLASS_DEVICE
+  void store(Fragment const &frag) { store_with_byte_offset(frag, 0); }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Specialization of EllPredicatedTileIterator for pitch-linear data.
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize +> +class EllPredicatedTileIterator { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessSize + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Iterator for ELL storage + using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend EllPredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { + + } + }; + + +private: + + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset ///< Initial offset of threadblock + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) + ) { } + + /// Construct a EllPredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): EllPredicatedTileIterator(params, pointer, 
extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator operator++(int) { + EllPredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Returns a stride + CUTLASS_HOST_DEVICE + int get_stride() const { return iterator_.get_stride(); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// add mask for small tiles in ELL + CUTLASS_HOST_DEVICE + void ell_add_mask(int blocksize) { + iterator_.ell_add_mask(blocksize); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, 
pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) { + iterator_.load_with_ell_index(frag, ell_iter); + } + + CUTLASS_DEVICE + void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) { + iterator_.load_with_ell_index_fast(frag, ell_iter); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileIterator for pitch-linear data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize +> +class EllPredicatedTileIterator { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessSize + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Iterator for ELL storage + using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend EllPredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { + + }; + }; + + +private: + + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset ///< Initial offset of threadblock + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) + ) { } + + /// Construct a EllPredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): EllPredicatedTileIterator(params, pointer, 
extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator operator++(int) { + EllPredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Returns a stride + CUTLASS_HOST_DEVICE + int get_stride() const { return iterator_.get_stride(); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// add mask for small tiles in ELL + CUTLASS_HOST_DEVICE + void ell_add_mask(int blocksize) { + iterator_.ell_add_mask(blocksize); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, 
pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) { + iterator_.load_with_ell_index(frag, ell_iter); + } + + CUTLASS_DEVICE + void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) { + iterator_.load_with_ell_index_fast(frag, ell_iter); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileIterator for interleaved data. It is mapped +/// to the congruous layout. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template +class EllPredicatedTileIterator, + AdvanceRank, ThreadMap_, AccessSize> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileIterator< + layout::PitchLinearShape, + Element, layout::PitchLinear, (kAdvanceRank == 0 ? 
0 : 1), ThreadMap, AccessSize>; + + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Iterator for ELL storage + using EllIterator = typename cutlass::transform::threadblock::ell::Iterator; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend EllPredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a EllPredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating 
thread + ) + : EllPredicatedTileIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator operator++(int) { + EllPredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Returns a stride + CUTLASS_HOST_DEVICE + int get_stride() const { return iterator_.get_stride(); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// add mask for small tiles in ELL + CUTLASS_HOST_DEVICE + void ell_add_mask(int blocksize) { iterator_.ell_add_mask(blocksize); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index 
pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + CUTLASS_DEVICE + void load_with_ell_index(Fragment &frag, EllIterator& ell_iter) { + iterator_.load_with_ell_index(frag, ell_iter); + } + + CUTLASS_DEVICE + void load_with_ell_index_fast(Fragment &frag, EllIterator& ell_iter) { + iterator_.load_with_ell_index_fast(frag, ell_iter); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of EllPredicatedTileIterator for interleaved-32 data. It is +/// mapped to the congruous layout. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class EllPredicatedTileIterator, + AdvanceRank, ThreadMap_, AccessSize> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = EllPredicatedTileIterator< + layout::PitchLinearShape, + Element, layout::PitchLinear, (kAdvanceRank == 0 ? 
1 : 0), ThreadMap, AccessSize>; + + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend EllPredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a EllPredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : EllPredicatedTileIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a 
pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + EllPredicatedTileIterator operator++(int) { + EllPredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Returns a stride + CUTLASS_HOST_DEVICE + int get_stride() const { return iterator_.get_stride(); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// add mask for small tiles in ELL + CUTLASS_HOST_DEVICE + void ell_add_mask(int blocksize) { iterator_.ell_add_mask(blocksize); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + 
CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..dab597c835ced1a4f070858b26da3007d268c04e --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_access_iterator.h @@ -0,0 +1,375 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Templates calculating the address and predicates to the load of scale and bias vectors. + + This iterator uses masks to guard out-of-bounds accesses. + + It can be used to load the gamma and beta vectors of layernorm which is loop variant. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/conv/threadblock/conv2d_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedScaleBiasVectorAccessIterator +/// +template +class PredicatedScaleBiasVectorAccessIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for fprop pitch-linear data. +/// +template +class PredicatedScaleBiasVectorAccessIterator { + public: + + using ThreadblockShape = ThreadblockShape_; + using Element = Element_; + using Layout = layout::PitchLinear; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + static int const kElementsPerAccess = 128 / sizeof_bits::value; + static int const kThreads = ThreadblockShape::kContiguous / kElementsPerAccess; + + using AccessType = AlignedArray; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Internal pointer to first access of tile + BytePointer pointer_; + + TensorCoord thread_offset_; + + int problem_size_k_; + + /// Used for out-of-order visitation + bool is_residue_tile_; + + bool guard_; + + TensorCoord::Index residue_size_; + + public: + 
/// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + /// Extent of tensor + int problem_size_k, + /// Pointer to the start of the scale vector + ConstPointer scale_pointer, + /// Pointer to the start of the bias vector + ConstPointer bias_pointer, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) { + pointer_ = (thread_id < kThreads) + ? reinterpret_cast( + const_cast(scale_pointer)) + : reinterpret_cast( + const_cast(bias_pointer)); + + // Per-thread offset in logical coordinates of tensor + int thread_base = (thread_id < kThreads) ? 0 : kThreads; + + problem_size_k_ = problem_size_k; + + is_residue_tile_ = true; + + residue_size_ = (problem_size_k_ - threadblock_offset.contiguous()) % ThreadblockShape::kContiguous; + + if (residue_size_ == 0) { + residue_size_ = ThreadblockShape::kContiguous; + } + + guard_ = ((thread_id - thread_base) * kElementsPerAccess) < residue_size_; + + thread_offset_ = + threadblock_offset + + TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0); + + set_iteration_index(0); + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + /// Extent of tensor + int problem_size_k, + /// Pointer to start of scale vector + ConstPointer scale_pointer, + /// Pointer to start of scale vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id) + : PredicatedScaleBiasVectorAccessIterator(problem_size_k, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole threadblock tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord 
const &tile_offset) { + + guard_ = threadIdx.x < kThreads * 2; + + TensorCoord offset = is_residue_tile_ ? + TensorCoord(residue_size_ + ThreadblockShape::kContiguous * (tile_offset.contiguous() - 1), 0) + : TensorCoord(ThreadblockShape::kContiguous * tile_offset.contiguous(), 0); + + thread_offset_ = + thread_offset_ + + offset; + + is_residue_tile_ = false; + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + + return reinterpret_cast( + pointer_ + + (thread_offset_.contiguous() * sizeof_bits::value / 8)); + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator &operator++() { + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_DEVICE + PredicatedScaleBiasVectorAccessIterator operator++(int) { + PredicatedScaleBiasVectorAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + guard_ &= (!enable); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return guard_; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for row-major data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedScaleBiasVectorAccessIterator { + public: + + using ThreadblockShape = ThreadblockShape_; + using Element = Element_; + using Layout = layout::RowMajor; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedScaleBiasVectorAccessIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear>; + + using AccessType = typename UnderlyingIterator::AccessType; + static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + ///< Extent of tensor + int problem_size_k, + ///< Pointer to the start of the scale vector + ConstPointer scale_pointer, + ///< Pointer to the start of the bias vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(problem_size_k, scale_pointer, bias_pointer, + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + int problem_size_k, ///< Extent of tensor + ConstPointer scale_pointer, ///< Pointer to the start of the scale 
vector + ConstPointer bias_pointer, ///< Pointer to the start of the bias vector + int thread_id ///< ID of each participating thread + ) + : PredicatedScaleBiasVectorAccessIterator(problem_size_k, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// threadblock tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator operator++(int) { + PredicatedScaleBiasVectorAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..e5d9e70d73bfcbdc27ab78bbedea1278c3b25950 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_scale_bias_vector_iterator.h @@ -0,0 +1,328 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Templates calculating the address and predicates to the load of scale and bias vectors. + + This iterator uses masks to guard out-of-bounds accesses. + + This can be used to load var and mean vectors in layernorm which is loop invariant. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedScaleBiasVectorIterator +/// +template +class PredicatedScaleBiasVectorIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for wgrad pitch-linear data. +/// +template +class PredicatedScaleBiasVectorIterator { + public: + + using WarpShape = WarpShape_; + using Element = Element_; + using Layout = layout::PitchLinear; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + static int const kElementsPerAccess = 1; + + using AccessType = AlignedArray; + + static int const kIterations = WarpShape::kContiguous / 8; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array<__half2, 2 * kIterations * kElementsPerAccess>; + + private: + // + // Data members + // + + /// Internal pointer to first access of tile + ConstPointer scale_pointer_; + ConstPointer bias_pointer_; + + /// Size of tensor + int problem_size_; + + int32_t thread_offset_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + /// Extent of 
tensor + int problem_size, + /// Pointer to the start of the scale vector + ConstPointer scale_pointer, + /// Pointer to the start of the bias vector + ConstPointer bias_pointer, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : problem_size_(problem_size), + scale_pointer_(scale_pointer), + bias_pointer_(bias_pointer) { + + thread_offset_ = threadblock_offset.contiguous() + (thread_id % 32) / 4; + } + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + /// Extent of tensor + int problem_size, + /// Pointer to start of scale vector + ConstPointer scale_pointer, + /// Pointer to start of scale vector + ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id) + : PredicatedScaleBiasVectorIterator(problem_size, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole warp tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + + thread_offset_ += (WarpShape::kContiguous * tile_offset.contiguous()); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + frag.fill(__float2half2_rn(0.0f)); + __half2 *frag_ptr = reinterpret_cast<__half2 *>(&frag); + + // load scale + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + + cutlass::arch::global_load< + __half, + sizeof(AccessType) + >( + frag_ptr[c * 2].x, + scale_pointer_ + thread_offset_ + c * 8, + (thread_offset_ + c * 8) < problem_size_ + ); + } + + // load bias + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + + cutlass::arch::global_load< + __half, + sizeof(AccessType) + >( + frag_ptr[c * 2 + 1].x, + bias_pointer_ + thread_offset_ + c * 8, + (thread_offset_ + c * 8) < problem_size_ + ); + } + + // 
duplicate scale + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + frag_ptr[c * 2].y = frag_ptr[c * 2].x; + } + + // duplicate bias + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + frag_ptr[c * 2 + 1].y = frag_ptr[c * 2 + 1].x; + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedScaleBiasVectorIterator { + public: + + using WarpShape = WarpShape_; + using Element = Element_; + using Layout = layout::RowMajor; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedScaleBiasVectorIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear>; + + using AccessType = typename UnderlyingIterator::AccessType; + static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; + using Fragment = typename UnderlyingIterator::Fragment; + + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + ///< Extent of tensor + int problem_size, + ///< Pointer to the start of the scale vector + ConstPointer scale_pointer, + ///< Pointer to the start of the bias vector 
+ ConstPointer bias_pointer, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(problem_size, scale_pointer, bias_pointer, + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorIterator( + int problem_size, ///< Extent of tensor + ConstPointer scale_pointer, ///< Pointer to the start of the scale vector + ConstPointer bias_pointer, ///< Pointer to the start of the bias vector + int thread_id ///< ID of each participating thread + ) + : PredicatedScaleBiasVectorIterator(problem_size, + scale_pointer, bias_pointer, + thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// threadblock tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + iterator_.load(frag); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..3640709868602584f93e3409a251c0baff19d18d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h @@ -0,0 +1,2118 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates calculating the address and predicates to the load of tiles + from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile this + iterator visits maybe partial, then the remaining tiles are complete. So, we + only need to compute the predicates twice, once before the first tile and + once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/permute.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileAccessIteratorPredicates +/// +template +class PredicatedTileAccessIteratorPredicates { + public: + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorCoord = typename Layout::TensorCoord; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static int const kPredicatesPerByte = 4; + static int const kPredicatesPerWord = 4 * kPredicatesPerByte; + + static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector; + + /// Number of 32b words containing predicates + static int const kPredicateByteCount = + (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte; + static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; + + static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; + + 
static_assert(kPredicateWordCount <= 4, "Too many predicates."); + + /// Predicate vector stores mask to guard accesses + using Mask = Array; + +// private: + /// Guard predicates + uint32_t predicates_[kPredicateWordCount]; + + /// Size of tensor + TensorCoord extent_; + + /// Initial offset for each thread + TensorCoord thread_offset_; + + /// Offset to the first steady-state tile + TensorCoord residue_offset_; + + /// Iteration along vectors implied by the thread map + int iteration_vector_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = 0u; + } + + CUTLASS_PRAGMA_UNROLL + for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { + + int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int c = access_residual / kAccessesPerVector; + int v = access_residual % kAccessesPerVector; + + TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, + s * ThreadMap::Delta::kStrided); + + TensorCoord coord = thread_offset_ + iteration_coord; + + bool guard; + + if (is_steady_state) { + if (kAdvanceRank == 0) { + guard = (coord.strided() < extent.strided()); + } else { + guard = (coord.contiguous() < extent.contiguous()); + } + } else { + guard = (coord.strided() < extent.strided() && + coord.contiguous() < extent.contiguous()); + } + + int pred_idx = v + kAccessesPerVector * (c + 
ThreadMap::Iterations::kContiguous * s); + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); + + } + + } + + CUTLASS_HOST_DEVICE + void set_predicates(int thread_id, TensorCoord const &threadblock_offset) { + + TensorCoord residue_extent; + if (kAdvanceRank) { + + typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.strided()) % Shape::kStrided; + if (!residue_size) { + residue_size = Shape::kStrided; + } + + residue_offset_ = make_Coord(0, residue_size); + residue_extent = make_Coord( + extent_.contiguous(), + min(threadblock_offset.strided() + residue_size, extent_.strided()) + ); + } else { + + typename TensorCoord::Index residue_size = (extent_[kAdvanceRank] - threadblock_offset.contiguous()) % Shape::kContiguous; + if (!residue_size) { + residue_size = Shape::kContiguous; + } + + residue_offset_ = make_Coord(residue_size, 0); + + residue_extent = make_Coord( + min(extent_.contiguous(), threadblock_offset.contiguous() + residue_size), + extent_.strided() + ); + } + + // Per-thread offset in logical coordinates of tensor + thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id); + + compute_predicates_(residue_extent, false); + + set_iteration_index(0); + } + + /// Default constructor + PredicatedTileAccessIteratorPredicates() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorPredicates( + /// Extent of tensor + TensorCoord extent) + : extent_(extent) { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + + 
iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorPredicates &operator++() { + + return *this; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = enable ? 0u : predicates_[i]; + } + + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = 0xffffffff; + } + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = mask[i]; + } + + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + mask[i] = predicates_[i]; + } + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { + + + int pred_idx = + iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; + return pred; + + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileAccessIterator +/// +template +class PredicatedTileAccessIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// 
Specialization of PredicatedTileAccessIterator for pitch-linear data. +/// +template +class PredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, Element, Layout, AdvanceRank, ThreadMap, AccessType>; + + /// Number of AccessType accesses that make up one ThreadMap vector access + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static bool constexpr Permute = !platform::is_same::value + && !platform::is_same>::value; + + using Mask = typename UnderlyingPredicates::Mask; + + /// Parameters object; reuses the non-template PredicatedTileAccessIteratorParams as its base + struct Params : PredicatedTileAccessIteratorParams { + + using Base = PredicatedTileAccessIteratorParams; + + /// Default constructor + Params() = default; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : + Base(layout.stride(0), + MakePredicatedTileAccessIteratorDesc()() + ) { } + + /// Construct the Params object by copying precomputed base parameters + CUTLASS_HOST_DEVICE + Params(Base const &base) : + Base(base) { } + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + 
UnderlyingPredicates the_predicates; + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// True until the first add_tile_offset() call, which applies the residue offset once and then clears this flag + bool is_residue_tile_; + + /// Below is used when Gather is turned on. We need to record strided_offset + /// and contiguous_offset separately to compute the offset by using + /// + /// offset = contiguous_offset + indices[strided_offset] + + /// Gather indices + int const *indices_; + + /// Function to perform layout permutation and offset computation + PermuteLayout permute_layout_; + + /// Tracks thread's coordinate offset in the matrix for current tile. + /// This is only used in the following cases: + /// - when Gather is true, strided coordinate needed to access indices (contiguous offset is tracked via pointer_) + /// - when Permute is true, both coordinates are needed as input into permutation function (pointer_ is fixed) + TensorCoord coord_offset_; + + private: + /// Computes predicates based on internally tracked per-thread offset. 
+ CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + + /// Default constructor + PredicatedTileAccessIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + /// Gather indices; must be non-null when Gather is enabled (asserted below) + int const *indices = nullptr) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent), + is_residue_tile_(true), + indices_(indices), + permute_layout_(TensorCoord(extent.contiguous(), extent.strided()), params.stride_) { + + the_predicates.set_predicates(thread_id, threadblock_offset); + + if (Gather) { + assert(indices_); + } + + // update internal pointers: without Gather/Permute the full thread offset is folded into pointer_; + // otherwise the offset is kept in coord_offset_ (Gather-only still folds the contiguous part into pointer_) + Layout layout(params_.stride_); + + if (!Gather && !Permute) { + add_pointer_offset(layout(the_predicates.thread_offset_)); + } else { + coord_offset_ = the_predicates.thread_offset_; + if (!Permute) { + add_pointer_offset(layout(make_Coord(coord_offset_.contiguous(), 0))); + } + } + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration 
index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + /// Adds a pointer offset in units of Element (converted to bytes via sizeof_bits before being applied to pointer_) + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + /// The first call consumes the residue tile: it applies residue_offset_, recomputes predicates in steady-state mode, and advances one tile less along the advance rank. + /// NOTE(review): in the steady-state branch the non-advance-rank adjustment is added to the byte pointer without the sizeof_bits scaling used in the residue branch - confirm this is intended. + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + if (is_residue_tile_) { + + the_predicates.thread_offset_ += the_predicates.residue_offset_; + + the_predicates.compute_predicates_(the_predicates.extent_, true); + + Layout layout(params_.stride_); + + if (!Gather && !Permute) { + add_pointer_offset(layout(the_predicates.residue_offset_)); + + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided() - 1); + pointer_ += Shape::kContiguous * tile_offset.contiguous() * sizeof_bits::value / 8; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous() - 1); + pointer_ += Shape::kStrided * tile_offset.strided() * sizeof_bits::value / 8; + } + } else { + coord_offset_.strided() = the_predicates.thread_offset_.strided() + Shape::kStrided * (tile_offset.strided() - kAdvanceRank); + if (!Permute) { + add_pointer_offset(layout(make_Coord(the_predicates.residue_offset_.contiguous(), 0))); + add_pointer_offset(Shape::kContiguous * (tile_offset.contiguous() - (1 - kAdvanceRank))); + } else { + coord_offset_.contiguous() = the_predicates.thread_offset_.contiguous() + Shape::kContiguous * (tile_offset.contiguous() - (1 - kAdvanceRank)); + } + } + } else { + if (!Gather && !Permute) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } else { + coord_offset_.strided() 
+= Shape::kStrided * tile_offset.strided(); + if (!Permute) { + add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); + } else { + coord_offset_.contiguous() += Shape::kContiguous * tile_offset.contiguous(); + } + } + } + + is_residue_tile_ = false; + } + + /// Returns a pointer to the current access; on the Gather/Permute path this is nullptr when the access is predicated off + CUTLASS_HOST_DEVICE + AccessType *get() const { + + if (Gather || Permute) + { + if (!valid()) { + return nullptr; + } + + Index coord_contig = (Permute ? coord_offset_.contiguous() : 0) + the_predicates.iteration_contiguous_ * ThreadMap::Delta::kContiguous + the_predicates.iteration_vector_ * AccessType::kElements; + Index coord_strided = coord_offset_.strided() + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; + // Gather: translate the strided coordinate through the index table + if (Gather) { + coord_strided = indices_[coord_strided]; + } + + LongIndex offset = Permute ? permute_layout_(TensorCoord(coord_contig, coord_strided)) : (coord_strided * LongIndex(params_.stride_) + coord_contig); + return reinterpret_cast(pointer_ + OffsetBytes(offset)); + } + + return reinterpret_cast( + pointer_ + + the_predicates.iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value) / 8) + the_predicates.iteration_vector_; + } + + /// Pre-increment: advances to the next access and returns a reference to self. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + + the_predicates.operator++(); + + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + + // Enter here only if (iteration_contiguous_ == ThreadMap::Iterations::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + if (!Gather && !Permute) { + pointer_ += params_.inc_strided_; + } + + return *this; + } + + // Enter here only if (iteration_strided_ == ThreadMap::Iterations::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + if (!Gather && !Permute) { + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + } + + return *this; + } + + /// Post-increment: advances the iterator and returns a copy of its previous state. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + the_predicates.clear_mask(enable); + } + + /// Enables the predicate set efficiently (forwards to the_predicates.enable_mask()) + CUTLASS_HOST_DEVICE + void enable_mask() { + the_predicates.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + the_predicates.set_mask(mask); + } + + /// Gets the mask into the caller-provided output argument + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + the_predicates.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for column-major data. 
+/// +/// Thin adapter over the pitch-linear specialization: (row, column) coordinates are forwarded +/// as PitchLinearCoord(row, column). +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + /// Pitch-linear iterator that every member function below forwards to + using UnderlyingIterator = PredicatedTileAccessIterator< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 
0 : 1), ThreadMap, AccessType, + Gather, PermuteLayout>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIterator; + + /// Parameters of the underlying pitch-linear iterator + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given a column-major tensor's layout (its stride(0) becomes the pitch-linear stride) + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object from precomputed base parameters of the underlying iterator + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + + /// Default constructor + PredicatedTileAccessIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices, forwarded unchanged to the underlying pitch-linear iterator + ) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), + threadblock_offset.column()), + indices) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + 
Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles; the (row, column) tile offset is forwarded in that order + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns the pointer produced by the underlying iterator's get() + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Pre-increment: advances the underlying pitch-linear iterator and returns a + /// reference to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Post-increment: advances the underlying iterator and returns a copy of this + /// iterator's previous state. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Enables the predicate set efficiently (forwards to iterator_.enable_mask()) + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask into the caller-provided output argument + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + /// NOTE(review): not const-qualified, unlike valid() on the pitch-linear specialization + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for row-major data. 
+/// +/// Thin adapter over the pitch-linear specialization: (row, column) coordinates are swapped +/// into PitchLinearCoord(column, row) before being forwarded. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileAccessIterator< + layout::PitchLinearShape, Element, + 
layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType, + Gather, PermuteLayout>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIterator; + + /// Parameters of the underlying pitch-linear iterator + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given a row-major tensor's layout (its stride(0) becomes the pitch-linear stride) + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object from precomputed base parameters of the underlying iterator + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + + /// Default constructor + PredicatedTileAccessIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID. Row-major (row, column) coordinates are swapped into + /// PitchLinearCoord(column, row) for the underlying iterator. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + /// Gather indices, forwarded unchanged to the underlying pitch-linear iterator + int const *indices = nullptr) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row()), + indices) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + Params const ¶ms, ///< 
Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles; the (row, column) tile offset is forwarded as (column, row) + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns the pointer produced by the underlying iterator's get() + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Pre-increment: advances the underlying pitch-linear iterator and returns a + /// reference to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Post-increment: advances the underlying iterator and returns a copy of this + /// iterator's previous state. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Enables the predicate set efficiently (forwards to iterator_.enable_mask()) + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask into the caller-provided output argument + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + /// NOTE(review): this accessor is not const-qualified although it only forwards to iterator_.valid() + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for affine rank 2 data +/// (only for Gather = false with layout::NoPermute, per the template arguments below). 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIterator, + AdvanceRank, ThreadMap_, AccessType_, false, + layout::NoPermute> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingPredicates = 
PredicatedTileAccessIteratorPredicates< + Shape, Element, layout::PitchLinear, AdvanceRank, ThreadMap, AccessType>; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingPredicates::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileAccessIterator; + + private: + /// strides of the two affine ranks, taken from layout.stride(0) and layout.stride(1) (units of Element) + Coord stride_; + /// amount (in byte) to increment pointer to move to next access along + /// contiguous dimension + LongIndex inc_contiguous_; + /// amount (in byte) to increment pointer from first access of current + /// contiguous dimension to first access of next one. + LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access of current + /// contiguous dimension to first access of next one. 
+ LongIndex inc_next_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + public: + + // Default ctor + // NOTE(review): inc_next_strided_ is not initialized by this constructor - confirm it is never read from a default-constructed Params + CUTLASS_HOST_DEVICE + Params(): stride_(0), inc_contiguous_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { } + + /// Construct the Params object given an affine rank-2 tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : stride_({layout.stride(0), layout.stride(1)}) { + inc_contiguous_ = (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) * + sizeof_bits::value / 8; + + inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) * + sizeof_bits::value / 8; + + inc_next_strided_ = inc_strided_ - LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_; + + if (kAdvanceRank) { + // advance along strided dimension + inc_advance_ = + Shape::kStrided * LongIndex(stride_[1]) * sizeof_bits::value / 8; + } else { + // advance along contiguous dimension + // NOTE(review): stride_[0] is not cast to LongIndex here, unlike the strided branch - check for overflow on large strides + inc_advance_ = Shape::kContiguous * stride_[0] * sizeof_bits::value / 8; + } + + inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_; + }; + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + // + // Data members + // + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + UnderlyingPredicates the_predicates; + + /// True until the first add_tile_offset() call, which applies the residue offset once and then clears this flag + bool is_residue_tile_; + + private: + /// Computes predicates based on internally tracked per-thread offset. 
+ CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + + /// Default constructor + PredicatedTileAccessIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices; accepted but unused, no support for gather/scatter at this specialization + ) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent), + is_residue_tile_(true) { + + the_predicates.set_predicates(thread_id, threadblock_offset); + + // update internal pointers: fold the thread's starting offset into the byte pointer + Layout layout(params_.stride_); + add_pointer_offset(layout(the_predicates.thread_offset_)); + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { the_predicates.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element (converted to bytes before being applied to pointer_) + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value 
* pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + /// The first call consumes the residue tile (applies residue_offset_ and recomputes predicates in steady-state mode). + /// NOTE(review): the non-advance-rank adjustments (Shape::kContiguous * tile_offset[0], Shape::kStrided * tile_offset[1]) are added to the byte pointer without stride or element-size scaling - confirm. + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + if (is_residue_tile_) { + + the_predicates.thread_offset_ += the_predicates.residue_offset_; + + Layout layout(params_.stride_); + add_pointer_offset(layout(the_predicates.residue_offset_)); + + the_predicates.compute_predicates_(the_predicates.extent_, true); + + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1] - 1); + pointer_ += Shape::kContiguous * tile_offset[0]; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0] - 1); + pointer_ += Shape::kStrided * tile_offset[1]; + } + } else { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]); + pointer_ += Shape::kContiguous * tile_offset[0]; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]); + pointer_ += Shape::kStrided * tile_offset[1]; + } + } + is_residue_tile_ = false; + } + + /// Returns a pointer to the current access (pointer_ offset by the vector iteration index) + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(pointer_) + the_predicates.iteration_vector_; + } + + /// Pre-increment: advances to the next access, moving the byte pointer by the + /// precomputed increments; after the last access of the tile the pointer is + /// returned to the tile's first access. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + the_predicates.operator++(); + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + pointer_ += params_.inc_contiguous_; + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iterations::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_next_strided_; + return *this; + } + + // Enter here only if (iteration_strided_ == ThreadMap::Iterations::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Post-increment: advances the iterator and returns a copy of its previous + /// state. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); } + + /// Enables the predicate set efficiently (forwards to the_predicates.enable_mask()) + CUTLASS_HOST_DEVICE + void enable_mask() { the_predicates.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { the_predicates.set_mask(mask); } + + /// Gets the mask into the caller-provided output argument + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { the_predicates.get_mask(mask); } + + /// Returns whether access is valid or not + /// NOTE(review): this accessor is not const-qualified although it only forwards to the_predicates.valid() + CUTLASS_HOST_DEVICE + bool valid() { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for affine rank 2 column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + // Map to the underlying AffineRankN<2> layout + using 
UnderlyingIterator = PredicatedTileAccessIterator< + layout::PitchLinearShape, Element, + layout::AffineRankN<2>, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIterator; + + /// Parameters of the underlying AffineRankN<2> iterator + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given an AffineRank2ColumnMajor layout; its strides are + /// passed to AffineRankN<2> in (stride(0), stride(1)) order + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + + /// Default constructor + PredicatedTileAccessIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices; accepted but not forwarded, no support for gather/scatter at this specialization + ) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + 
Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles; the (row, column) tile offset is forwarded in that order + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset(make_Coord(tile_offset.row(), tile_offset.column())); + } + + /// Returns the pointer produced by the underlying iterator's get() + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Pre-increment: advances the underlying AffineRankN<2> iterator and returns a + /// reference to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Post-increment: advances the underlying iterator and returns a copy of this + /// iterator's previous state. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Enables the predicate set efficiently (forwards to iterator_.enable_mask()) + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask into the caller-provided output argument + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + /// NOTE(review): this accessor is not const-qualified although it only forwards to iterator_.valid() + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for affine rank-2 row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = 
PredicatedTileAccessIterator< + layout::PitchLinearShape, Element, + layout::AffineRankN<2>, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + + /// Default constructor + PredicatedTileAccessIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< 
Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset(make_Coord(tile_offset.column(), tile_offset.row())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for column-major interleaved data. +/// It is mapped to the congruous layout. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template +class PredicatedTileAccessIterator, + AdvanceRank, ThreadMap_, AccessType_, false, + layout::NoPermute> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileAccessIterator< + layout::PitchLinearShape, + Element, layout::PitchLinear, (kAdvanceRank == 0 ? 
0 : 1), ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + + /// Default constructor + PredicatedTileAccessIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed 
parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for row-major interleaved data. +// It is mapped to the congruous layout. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIterator, + AdvanceRank, ThreadMap_, AccessType_, false, + layout::NoPermute> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileAccessIterator< + layout::PitchLinearShape, + Element, layout::PitchLinear, (kAdvanceRank == 0 ? 
1 : 0), ThreadMap, + AccessType>; + + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + + /// Default constructor + PredicatedTileAccessIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator( + Params const ¶ms, ///< Precomputed 
parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator operator++(int) { + PredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { return iterator_.valid(); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h new file mode 100644 index 0000000000000000000000000000000000000000..93eac72e40ddf6b0f3d268957873417e5d5a442f --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h @@ -0,0 +1,834 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates calculating the address and predicates to the load of tiles + from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses and visits the last + "residue" tile first, with the objective of minimizing predicate mask updates + during steady-state operation. 
+ + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileAccessIterator2dThreadTile +/// +template +class PredicatedTileAccessIterator2dThreadTile; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator2dThreadTile for pitch-linear data. 
+/// +template +class PredicatedTileAccessIterator2dThreadTile { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + static int const kPredicatesPerByte = 4; + static int const kPredicatesPerWord = 4 * kPredicatesPerByte; + + /// Number of 32b words containing predicates + static int const kPredicateByteCount = (ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kStrided + kPredicatesPerByte - 1) / kPredicatesPerByte; + static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; + + static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; + + static_assert(kPredicateWordCount <= 4, "Too many predicates."); + + /// Predicate vector stores mask to guard accesses + using Mask = Array; + + /// Uses a non-template class + struct Params : PredicatedTileAccessIteratorParams { + + public: + friend PredicatedTileAccessIterator2dThreadTile; + + using Base = PredicatedTileAccessIteratorParams; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : + Base(layout.stride(0), + MakePredicatedTileAccessIteratorDesc()() + ) { } + + CUTLASS_HOST_DEVICE + Params(Base const &base) : + Base(base) { } + }; + + + private: + 
/// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const ¶ms_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Guard predicates + uint32_t predicates_[kPredicateWordCount]; + + /// Size of tensor + TensorCoord extent_; + + /// Initial offset for each thread + TensorCoord thread_offset_; + + /// Index of residue tile + int residue_tile_idx_; + + /// Used for out-of-order visitation + bool is_residue_tile_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + /// Tracks iterations within the thread loop + int iteration_thread_; + + private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_HOST_DEVICE + void compute_predicates_( + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = 0u; + } + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++) { + + TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous, + ts + s * ThreadMap::Delta::kStrided); + + TensorCoord coord = thread_offset_ + iteration_coord; + + bool guard; + + if (is_steady_state) { + if (kAdvanceRank == 0) { + guard = (coord.strided() < extent_.strided()); + } else { + guard = (coord.contiguous() < extent_.contiguous()); + } + } else { + guard = (coord.strided() < extent_.strided() && + coord.contiguous() < extent_.contiguous()); + } + + int pred_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + s * 
ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); + + } + } + } + + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + extent_(extent), + is_residue_tile_(true) { + + + TensorCoord residue_offset; + if (kAdvanceRank) { + residue_tile_idx_ = + (extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) / + Shape::kStrided; + residue_offset = make_Coord(0, residue_tile_idx_ * Shape::kStrided); + } else { + residue_tile_idx_ = + (extent_[kAdvanceRank] - threadblock_offset[kAdvanceRank] - 1) / + Shape::kContiguous; + residue_offset = make_Coord(residue_tile_idx_ * Shape::kContiguous, 0); + } + + // Per-thread offset in logical coordinates of tensor + thread_offset_ = threadblock_offset + residue_offset + + ThreadMap::initial_offset(thread_id); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(thread_offset_)); + + compute_predicates_(false); + + set_iteration_index(0); + } + + /// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of 
tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + int residual = index % (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided); + iteration_strided_ = index / (ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided); + + iteration_contiguous_ = residual / ThreadMap::ThreadAccessShape::kStrided; + iteration_thread_ = residual % ThreadMap::ThreadAccessShape::kStrided; + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += int(sizeof(Element)) * pointer_offset; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + if (is_residue_tile_) { + TensorCoord residue_offset; + if (kAdvanceRank) { + residue_offset = TensorCoord(0, residue_tile_idx_ * Shape::kStrided); + } else { + residue_offset = TensorCoord(residue_tile_idx_ * Shape::kContiguous, 0); + } + + thread_offset_ -= residue_offset; + + Layout layout(params_.stride_); + add_pointer_offset(-layout(residue_offset)); + + compute_predicates_(true); + + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * (tile_offset.strided() - 1); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * (tile_offset.contiguous() - 1); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } else { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * tile_offset.strided(); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * tile_offset.contiguous(); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } + 
is_residue_tile_ = false; + } + + CUTLASS_HOST_DEVICE + AccessType *get() const { + + AccessType *ret_val = reinterpret_cast( + pointer_ + (iteration_thread_ * params_.stride_ + iteration_contiguous_ * ThreadMap::Delta::kContiguous) * int(sizeof(Element))); + + return ret_val; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile &operator++() { + + iteration_thread_++; + + if (iteration_thread_ < ThreadMap::ThreadAccessShape::kStrided) + return *this; + + iteration_thread_ = 0; + + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile operator++(int) { + PredicatedTileAccessIterator2dThreadTile self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = enable ? 
0u : predicates_[i]; + } + + } + + /// Sets every bit of every predicate word, enabling all accesses + CUTLASS_HOST_DEVICE + void enable_mask() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = 0xffffffff; + } + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = mask[i]; + } + + } + + /// Copies the current predicate words into `mask` + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + mask[i] = predicates_[i]; + } + } + + /// Returns whether the access at the current iteration indices is valid + CUTLASS_HOST_DEVICE + bool valid() { + + int pred_idx = + iteration_thread_ + + iteration_contiguous_ * ThreadMap::ThreadAccessShape::kStrided + + iteration_strided_ * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; + + // Locate the predicate bit: word, then byte within the word, then bit within the byte. + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; + + return pred; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator2dThreadTile for column-major data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Wraps a pitch-linear iterator, mapping rows to the contiguous dimension and +/// columns to the strided dimension. +/// +template <typename Shape_, typename Element_, int AdvanceRank, typename ThreadMap_, typename AccessType_> +class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::ColumnMajor, AdvanceRank, ThreadMap_, AccessType_> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef<Element, Layout>; + using TensorView = TensorView<Element, Layout>; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const<Element>::type *; + + using UnderlyingIterator = PredicatedTileAccessIterator2dThreadTile< + layout::PitchLinearShape<Shape::kRow, Shape::kColumn>, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 
0 : 1), ThreadMap, AccessType>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIterator2dThreadTile; + + /// Parameters object of the underlying pitch-linear iterator + typename UnderlyingIterator::Params params_; + + public: + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a column-major tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))){} + + /// Construct the Params object from the underlying iterator's base parameters + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile( + ///< Precomputed parameters object + Params const &params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile( + Params const &params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator2dThreadTile(params, 
pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles; (row, column) is forwarded as the pitch-linear (contiguous, strided) offset + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns the underlying iterator's current pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast<AccessType *>(iterator_.get()); + } + + /// Pre-increment: advances the underlying pitch-linear iterator by one + /// position and returns a reference to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile &operator++() { + ++iterator_; + return *this; + } + + /// Post-increment: advances the iterator and returns a copy of its prior + /// state. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile operator++(int) { + PredicatedTileAccessIterator2dThreadTile self(*this); + operator++(); + return self; + } + + /// Forwards to the underlying iterator's clear_mask + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Forwards to the underlying iterator's enable_mask + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator2dThreadTile for row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template <typename Shape_, typename Element_, int AdvanceRank, typename ThreadMap_, typename AccessType_> +class PredicatedTileAccessIterator2dThreadTile<Shape_, Element_, layout::RowMajor, AdvanceRank, ThreadMap_, AccessType_> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef<Element, Layout>; + using TensorView = TensorView<Element, Layout>; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const<Element>::type *; + + using UnderlyingIterator = 
PredicatedTileAccessIterator2dThreadTile< + layout::PitchLinearShape<Shape::kColumn, Shape::kRow>, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIterator2dThreadTile; + + /// Parameters object of the underlying pitch-linear iterator + typename UnderlyingIterator::Params params_; + + public: + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a row-major tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))){} + + /// Construct the Params object from the underlying iterator's base parameters + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile( + ///< Precomputed parameters object + Params const &params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIterator2dThreadTile with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile( + Params const &params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< 
Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIterator2dThreadTile(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles; (column, row) is forwarded as the pitch-linear (contiguous, strided) offset + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns the underlying iterator's current pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast<AccessType *>(iterator_.get()); + } + + /// Pre-increment: advances the underlying pitch-linear iterator by one + /// position and returns a reference to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile &operator++() { + ++iterator_; + return *this; + } + + /// Post-increment: advances the iterator and returns a copy of its prior + /// state. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIterator2dThreadTile operator++(int) { + PredicatedTileAccessIterator2dThreadTile self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h new file mode 100644 index 0000000000000000000000000000000000000000..5e509a344e955438ea4eabe6806ed2ab79343d36 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h @@ -0,0 +1,290 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/detail/helper_macros.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Predicated tile access iterator descriptor object containing template dependent state +struct PredicatedTileAccessIteratorDesc { + + int element_size_bits = -1; + int advance_rank = -1; + layout::PitchLinearCoord threadblock_shape; + layout::PitchLinearCoord threadmap_iterations; + layout::PitchLinearCoord threadmap_delta; + + // + // Methods + // + + PredicatedTileAccessIteratorDesc() = default; + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorDesc( + int element_size_bits_, + int advance_rank_, + layout::PitchLinearCoord threadblock_shape_, + layout::PitchLinearCoord threadmap_iterations_, + layout::PitchLinearCoord threadmap_delta_ + ): + element_size_bits(element_size_bits_), + advance_rank(advance_rank_), + threadblock_shape(threadblock_shape_), + threadmap_iterations(threadmap_iterations_), + threadmap_delta(threadmap_delta_) + { + #if 0 + printf("PredicatedTileAccessIteratorDesc(%d, %d, {%d, %d}, {%d, %d}, {%d, %d}})\n", + element_size_bits, + advance_rank, + threadblock_shape.contiguous(), threadblock_shape.strided(), + threadmap_iterations.contiguous(), threadmap_iterations.strided(), + threadmap_delta.contiguous(), threadmap_delta.strided()); + #endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Helper template to construct an PredicatedTileAccessIteratorDesc from a template +// dependent state +template < + typename Shape, typename Element, typename Layout, + int AdvanceRank, typename 
ThreadMap> + struct MakePredicatedTileAccessIteratorDesc; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for pitch-linear data. +template < + typename Shape, typename Element, int AdvanceRank, + typename ThreadMap> +struct MakePredicatedTileAccessIteratorDesc < + Shape, Element, layout::PitchLinear, AdvanceRank, ThreadMap> { + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorDesc operator()() { + + return PredicatedTileAccessIteratorDesc( + sizeof_bits::value, + AdvanceRank, + {Shape::kContiguous, Shape::kStrided}, + {ThreadMap::Iterations::kContiguous, ThreadMap::Iterations::kStrided}, + {ThreadMap::Delta::kContiguous, ThreadMap::Delta::kStrided} + ); +} + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for column-major data. +template < + typename Shape, typename Element, int AdvanceRank, + typename ThreadMap> +struct MakePredicatedTileAccessIteratorDesc < + Shape, Element, layout::ColumnMajor, AdvanceRank, ThreadMap> { + + static int const kAdvanceRank = AdvanceRank; + + using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap>; + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorDesc operator()() { + + return UnderlyingMakeOperator()(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIterator for row-major data. 
+template < + typename Shape, typename Element, int AdvanceRank, + typename ThreadMap> +struct MakePredicatedTileAccessIteratorDesc < + Shape, Element, layout::RowMajor, AdvanceRank, ThreadMap> { + + static int const kAdvanceRank = AdvanceRank; + + using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap>; + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorDesc operator()() { + + return UnderlyingMakeOperator()(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of MakePredicatedTileAccessIteratorDesc for column-major interleaved data. +/// Delegates to the pitch-linear specialization. +/// NOTE(review): the template arguments of layout::PitchLinearShape and +/// layout::ColumnMajorInterleaved appear to be missing in this copy - confirm against the original header. +template < + typename Shape, typename Element, int AdvanceRank, + typename ThreadMap, int InterleavedK> +struct MakePredicatedTileAccessIteratorDesc < + Shape, Element, layout::ColumnMajorInterleaved, AdvanceRank, ThreadMap> { + + static int const kAdvanceRank = AdvanceRank; + static int const kInterleavedK = InterleavedK; + + using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap>; + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorDesc operator()() { + + return UnderlyingMakeOperator()(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization of MakePredicatedTileAccessIteratorDesc for row-major interleaved data. 
+template < + typename Shape, typename Element, int AdvanceRank, + typename ThreadMap, int InterleavedK> +struct MakePredicatedTileAccessIteratorDesc < + Shape, Element, layout::RowMajorInterleaved, AdvanceRank, ThreadMap> { + + static int const kAdvanceRank = AdvanceRank; + static int const kInterleavedK = InterleavedK; + + using UnderlyingMakeOperator = MakePredicatedTileAccessIteratorDesc< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap>; + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorDesc operator()() { + + return UnderlyingMakeOperator()(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Parameters struct +// + +struct PredicatedTileAccessIteratorParams { + + using Index = int32_t; + using LongIndex = int64_t; + + // + // Data members + // + /// stride of pitch-linear layout (units of Element) + LongIndex stride_ = 0; + /// amount (in byte) to increment pointer to move to next access along + /// strided dimension + LongIndex inc_strided_ = 0; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_ = 0; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_ = 0; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Status initialize(LongIndex stride, PredicatedTileAccessIteratorDesc desc) { + CUTLASS_ASSERT(desc.element_size_bits > 0); + CUTLASS_ASSERT(desc.advance_rank == 0 || desc.advance_rank == 1); + + stride_ = stride; + + inc_strided_ = (LongIndex(stride_) * desc.threadmap_delta.strided()) * + desc.element_size_bits / 8; + + if (desc.advance_rank) { + // advance along strided dimension + inc_advance_ = + desc.threadblock_shape.strided() * LongIndex(stride_) * desc.element_size_bits / 8; + } else { + // advance along contiguous dimension + inc_advance_ = desc.threadblock_shape.contiguous() 
* desc.element_size_bits / 8; + } + + inc_next_ = inc_advance_ - LongIndex(desc.threadmap_iterations.strided() - 1) * + desc.threadmap_delta.strided() * LongIndex(stride_) * + desc.element_size_bits / 8; + + return Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Status initialize(Index stride, PredicatedTileAccessIteratorDesc desc) { + return initialize(LongIndex(stride), desc); + } + + PredicatedTileAccessIteratorParams() = default; + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorParams(Index stride, PredicatedTileAccessIteratorDesc desc) { + initialize(stride, desc); + } + + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorParams(LongIndex stride, PredicatedTileAccessIteratorDesc desc) { + initialize(stride, desc); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h new file mode 100644 index 0000000000000000000000000000000000000000..f657fe25813567b47156047f6ef023b678ac097f --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h @@ -0,0 +1,892 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates calculating the address and predicates to the load of tiles + from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses and visits the last + "residue" tile first, with the objective of minimizing predicate mask updates + during steady-state operation. 
+ + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileAccessIteratorTriangularMatrix +/// +template +class PredicatedTileAccessIteratorTriangularMatrix; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorTriangularMatrix for pitch-linear data. 
+/// +template +class PredicatedTileAccessIteratorTriangularMatrix { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + using CompareOp = typename TrMatrixCompareOp::Type; + + static_assert( kFillMode == FillMode::kFull || + ((kFillMode == FillMode::kLower || kFillMode == FillMode::kUpper) && AccessType::kElements == 1), + "BLAS3 iterator for the triangular/symmetric matrix must use AccessType::kElements as 1"); + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + static int const kPredicatesPerByte = 4; + static int const kPredicatesPerWord = 4 * kPredicatesPerByte; + + static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector; + + /// Number of 32b words containing predicates + static int const kPredicateByteCount = + (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte; + static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4; + + static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u; + + static_assert(kPredicateWordCount <= 4, "Too many predicates."); + + /// Predicate vector stores mask to 
guard accesses + using Mask = Array; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileAccessIteratorTriangularMatrix; + + private: + /// stride of pitch-linear layout (units of Element) + StrideIndex stride_; + /// (true) pitch-linear layout is mapped to row-major matrix + /// (false) pitch-linear layout is mapped to column-major matrix + bool is_row_major_; + /// for vectorized access across the diagonal boundary guard condition is + /// checked for the element on the boundary + int access_diagonal_boundary_; + /// amount (in byte) to increment pointer to move to next access along + /// strided dimension + LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + public: + + // Default ctor + CUTLASS_HOST_DEVICE + Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0), is_row_major_(false), access_diagonal_boundary_(0) { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout, bool is_row_major, int access_diagonal_boundary) : + stride_(layout.stride(0)), is_row_major_(is_row_major), access_diagonal_boundary_(access_diagonal_boundary) { + + inc_strided_ = (LongIndex(stride_) * ThreadMap::Delta::kStrided) * + sizeof_bits::value / 8; + + if (kAdvanceRank) { + // advance along strided dimension + inc_advance_ = + Shape::kStrided * LongIndex(stride_) * sizeof_bits::value / 8; + } else { + // advance along contiguous dimension + inc_advance_ = Shape::kContiguous * sizeof_bits::value / 8; + } + + inc_next_ = inc_advance_ - LongIndex(ThreadMap::Iterations::kStrided - 1) * + ThreadMap::Delta::kStrided * LongIndex(stride_) * + sizeof_bits::value / 8; + + }; + + + }; + + private: + /// Internal 
pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state (held by reference) + Params const &params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Guard predicates + uint32_t predicates_[kPredicateWordCount]; + + /// Track global memory addresses on the diagonal + /// To ignore imag part for diagonal elements of hermitian matrices + uint32_t predicates_onDiag_[kPredicateWordCount]; + + /// Size of tensor + TensorCoord extent_; + + /// Initial offset for each thread + TensorCoord thread_offset_; + + /// Iteration along vectors implied by the thread map + int iteration_vector_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent) { + + // Start from all-false guard and on-diagonal predicates. + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = 0u; + predicates_onDiag_[i] = 0u; + } + + CompareOp compare_op; + + CUTLASS_PRAGMA_UNROLL + for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) { + + // Decompose the flat access index into strided (s), contiguous (c) and vector (v) indices. + int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector); + + int c = access_residual / kAccessesPerVector; + int v = access_residual % kAccessesPerVector; + + TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements, + s * ThreadMap::Delta::kStrided); + + TensorCoord coord = thread_offset_ + iteration_coord; + + bool guard; + bool onDiag = false; + + guard = ((coord.strided() < extent.strided()) && + (coord.contiguous() < extent.contiguous())); + + + // guard access on the wrong side of the 
triagular matrix diagonal + if (kFillMode == FillMode::kLower || kFillMode == FillMode::kUpper) { + coord += TensorCoord{params_.access_diagonal_boundary_, 0}; + + bool triagular_guard_row_major = compare_op(coord.strided(), coord.contiguous()) | !params_.is_row_major_; + bool triagular_guard_col_major = compare_op(coord.contiguous(), coord.strided()) | params_.is_row_major_; + + guard = guard && triagular_guard_row_major && triagular_guard_col_major; + + if (kDiagType == DiagType::kUnit) { + onDiag = (guard && coord.strided() == coord.contiguous()) ? true : false; + } + } + + int pred_idx_onDiag = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); + int word_idx_onDiag = pred_idx_onDiag / kPredicatesPerWord; + int residual_onDiag = pred_idx_onDiag % kPredicatesPerWord; + int byte_idx_onDiag = residual_onDiag / kPredicatesPerByte; + int bit_idx_onDiag = residual_onDiag % kPredicatesPerByte; + + predicates_onDiag_[word_idx_onDiag] |= (unsigned(onDiag) << (byte_idx_onDiag * 8 + bit_idx_onDiag)); + + int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s); + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx)); + + } + + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : params_(params), + pointer_(reinterpret_cast(const_cast(pointer))), + extent_(extent) { + + + // Per-thread offset in logical 
coordinates of tensor + thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(thread_offset_)); + + compute_predicates_(extent_); + + set_iteration_index(0); + } + + /// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + + iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + thread_offset_ += TensorCoord{0, Shape::kStrided * tile_offset.strided()}; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); + pointer_ += Shape::kStrided * tile_offset.strided(); + thread_offset_ += TensorCoord{Shape::kContiguous * tile_offset.contiguous(), 0}; + } + + compute_predicates_(extent_); + } + + /// Returns a 
pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast( + pointer_ + + iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits::value) / 8) + iteration_vector_; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix &operator++() { + + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + + iteration_vector_ = 0; + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix operator++(int) { + PredicatedTileAccessIteratorTriangularMatrix self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = enable ? 
0u : predicates_[i]; + } + + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = 0xffffffff; + } + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + predicates_[i] = mask[i]; + } + + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kPredicateWordCount; ++i) { + mask[i] = predicates_[i]; + } + } + + /// Return if the address in on the diagonal + CUTLASS_HOST_DEVICE + bool getOnDiag() { + int pred_idx = + iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + bool pred = (predicates_onDiag_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; + return pred; + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + + + int pred_idx = + iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous); + + int word_idx = pred_idx / kPredicatesPerWord; + int residual = pred_idx % kPredicatesPerWord; + int byte_idx = residual / kPredicatesPerByte; + int bit_idx = residual % kPredicatesPerByte; + + bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0; + return pred; + + + //return true; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorTriangularMatrix for column-major data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorTriangularMatrix { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileAccessIteratorTriangularMatrix< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, + kSideMode, kFillMode, kDiagType, AccessType>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + static int const kAccessDiagonalBoundary = + (kFillMode == FillMode::kLower) ? 
(AccessType::kElements - 1) : 0; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorTriangularMatrix; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0)), false, kAccessDiagonalBoundary){}; + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in 
units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix operator++(int) { + PredicatedTileAccessIteratorTriangularMatrix self(*this); + operator++(); + return self; + } + + /// Clears the predicate set when `enable` is true (forwards to the underlying iterator) + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Sets every guard predicate (forwards to the underlying iterator's enable_mask()) + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding the value stored in the underlying iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Copies the underlying iterator's current predicate mask into `mask` + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether the current access is flagged as lying on the diagonal + CUTLASS_HOST_DEVICE + bool getOnDiag() { + return iterator_.getOnDiag(); + } + + /// Returns whether the current access is valid (its guard predicate is set) + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorTriangularMatrix for row-major data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileAccessIteratorTriangularMatrix { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileAccessIteratorTriangularMatrix< + layout::PitchLinearShape, Element, + layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, + kSideMode, kFillMode, kDiagType, AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + static int const kAccessDiagonalBoundary = + (kFillMode == FillMode::kUpper) ? 
(AccessType::kElements - 1) : 0; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorTriangularMatrix; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0)), true, kAccessDiagonalBoundary){}; + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix( + ///< Precomputed parameters object + Params const ¶ms, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIteratorTriangularMatrix with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorTriangularMatrix(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void 
set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorTriangularMatrix operator++(int) { + PredicatedTileAccessIteratorTriangularMatrix self(*this); + operator++(); + return self; + } + + /// Clears the predicate set when `enable` is true (forwards to the underlying iterator) + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Sets every guard predicate (forwards to the underlying iterator's enable_mask()) + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding the value stored in the underlying iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Copies the underlying iterator's current predicate mask into `mask` + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Returns whether the current access is flagged as lying on the diagonal + CUTLASS_HOST_DEVICE + bool getOnDiag() { + return iterator_.getOnDiag(); + } + + /// Returns whether the current access is valid (its guard predicate is set) + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..43c4cbd1a5758e0288f82babbe7043d22f83c009 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h @@ -0,0 +1,1887 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile this + iterator visits may be partial, then the remaining tiles are complete. 
So, we + only need to compute the predicates twice, once before the first tile and + once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be stored in registers, + and integer addition is used to advance the pointer through memory. +*/ + +#pragma once + +#include "cutlass/arch/memory.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileIterator +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize register liveness +/// and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. +/// +/// Base pointer and tensor extents may be specified at the time the iterator is constructed. +/// Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. +/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. +/// +/// Visitation order is intended to first visit a "residual" tile that may be partially full in +/// both the advance dimension and the steady-state dimension. This is assumed to be the last +/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to +/// the first tile that is full in the advance dimension and recomputes predicates. 
Subsequent +/// accesses may be performed without updating internal predicates and are efficient in terms of +/// live register state and pointer arithmetic instructions. +/// +/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once +/// outside any looping structure to minimize integer arithmetic. +/// +/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing +/// the iterator. +/// +/// +/// Example: +/// +/// An efficient pipeline structure may be constructed as follows: +/// +// template <typename Iterator> +// __global__ void kernel( +// typename Iterator::Params params, +// typename Iterator::Element *ptr, +// TensorCoord extent) { +// +// typename Iterator::Fragment fragment; +// +// TensorCoord threadblock_offset(0, 0); +// +// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offset); +// +// +// fragment = *iter; // load "residue" tile first +// ++iter; // advance to first "steady state" tile and update internal masks +// +// +// #pragma unroll +// for (int i = Remaining - 1; i >= 0; --i) { +// +// f(fragment); +// +// if (!i) { +// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. 
+// } +// +// fragment = *iter; // load tile during "steady state" phase +// ++iter; // advance to next tile - lightweight due to steady-state masks +// } +// } +// +// void host(TensorView view) { +// +// using Iterator = transform::threadblock::PredicatedTileIterator; +// +// typename Iterator::Params params(view.layout()); +// +// kernel(params, view.data()); +// } +/// +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize = ThreadMap::kElementsPerAccess, + bool Gather = false, + typename PermuteLayout = layout::NoPermute +> +class PredicatedTileIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + /// Type used for internal memory accesses + using AccessType = AlignedArray::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = + PredicatedTileAccessIterator; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object 
to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + using Base = typename TileAccessIterator::Params::Base; + + friend PredicatedTileIterator; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : params_(layout) {} + + /// Default constructor + Params() = default; + + CUTLASS_HOST_DEVICE + Params(Base const &base) + : params_(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + + /// Default constructor + PredicatedTileIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + /// Gather indices + int const *indices = nullptr) + : address_iterator_(params.params_, pointer, extent, thread_id, + threadblock_offset, indices) {} + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIterator(params, pointer, 
extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator &operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset({0, 1}); + else + address_iterator_.add_tile_offset({1, 0}); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIterator operator++(int) { + PredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { address_iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + + AccessType const *access_ptr = reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_byte_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + /// 
Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType *access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for column-major data. 
+/// Adapter over a pitch-linear PredicatedTileIterator (`UnderlyingIterator`): every member forwards to `iterator_`.
+/// Satisfies: ForwardTileIteratorConcept |
+///            ReadableContiguousTileIteratorConcept |
+///            WriteableContiguousTileIteratorConcept |
+///            MaskedTileIteratorConcept
+///
+template <
+  typename Shape_,
+  typename Element_,
+  int AdvanceRank,
+  typename ThreadMap_,
+  int AccessSize,
+  bool Gather,
+  typename PermuteLayout
+>
+class PredicatedTileIterator {  // NOTE(review): documented as a specialization, yet no argument list follows the class name in this text - looks stripped; confirm against upstream
+public:
+
+  static_assert(AdvanceRank == 0 || AdvanceRank == 1,
+    "Specialization for pitch-linear iterator may along advance along the "
+    "contiguous(rank=0) or strided(rank=1) dimension.");
+
+  using Shape = Shape_;
+  using Element = Element_;
+  using Layout = layout::ColumnMajor;
+  static int const kAdvanceRank = AdvanceRank;
+  using ThreadMap = ThreadMap_;
+
+  using Index = typename Layout::Index;
+  using LongIndex = typename Layout::LongIndex;
+
+  using TensorRef = TensorRef;
+  using TensorView = TensorView;
+  using TensorCoord = typename Layout::TensorCoord;
+
+  using Pointer = Element *;
+  using NonConstPointer = typename platform::remove_const::type *;
+
+  using UnderlyingIterator = PredicatedTileIterator<
+    layout::PitchLinearShape,
+    Element,
+    layout::PitchLinear,
+    (kAdvanceRank == 0 ? 0 : 1),  // advance rank is passed through unchanged
+    ThreadMap,
+    AccessSize,
+    Gather,
+    PermuteLayout
+  >;
+
+  using AccessType = typename UnderlyingIterator::AccessType;
+
+  /// Fragment object to be loaded or stored
+  using Fragment = cutlass::Array;
+
+  /// Predicate vector stores mask to guard accesses (same type as the underlying iterator's Mask)
+  using Mask = typename UnderlyingIterator::Mask;
+
+  /// Parameters object is precomputed state and is host-constructible; it only wraps the underlying iterator's Params
+  class Params {
+  private:
+
+    friend PredicatedTileIterator;
+
+    /// Parameters object of the underlying iterator
+    typename UnderlyingIterator::Params params_;
+
+  public:
+
+    /// Default constructor
+    Params() = default;
+
+    /// Construct the Params object from a column-major layout; only layout.stride(0) is used to build the pitch-linear params
+    CUTLASS_HOST_DEVICE
+    Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0)))
+    {}
+
+    CUTLASS_HOST_DEVICE
+    Params(typename UnderlyingIterator::Params::Base const &base)  // initializes params_ from an underlying Params::Base
+        : params_(base) {}
+  };
+
+
+private:
+
+  //
+  // Data members
+  //
+
+  /// Underlying pitch-linear tile iterator; all work is delegated to it
+  UnderlyingIterator iterator_;
+
+public:
+
+  /// Default constructor
+  PredicatedTileIterator() = default;
+
+  /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID; (row, column) is forwarded as PitchLinearCoord(row, column)
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator(
+    Params const ¶ms,                         ///< Precomputed parameters object (NOTE(review): "¶ms" looks like a mangled "&params" - the body reads params.params_; confirm)
+    Pointer pointer,                             ///< Pointer to start of tensor
+    TensorCoord extent,                          ///< Extent of tensor
+    int thread_id,                               ///< ID of each participating thread
+    TensorCoord const &threadblock_offset,       ///< Initial offset of threadblock
+    int const *indices = nullptr                 ///< gather/scatter indices, note no support for gather/scatter at this specialization (the pointer is still forwarded to the underlying iterator)
+  ):
+    iterator_(
+      params.params_,
+      pointer,
+      layout::PitchLinearCoord(extent.row(), extent.column()),
+      thread_id,
+      layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()),
+      indices)
+  { }
+
+  /// Construct a PredicatedTileIterator with zero threadblock offset (delegates to the constructor above with make_Coord(0, 0))
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator(
+    Params const ¶ms,       ///< Precomputed parameters object
+    Pointer pointer,           ///< Pointer to start of tensor
+    TensorCoord extent,        ///< Extent of tensor
+    int thread_id              ///< ID of each participating thread
+  ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
+
+  /// Adds a pointer offset in units of Element (forwarded to the underlying iterator)
+  CUTLASS_HOST_DEVICE
+  void add_pointer_offset(LongIndex pointer_offset) {
+    iterator_.add_pointer_offset(pointer_offset);
+  }
+
+  /// Advances to the next tile in memory by pre-incrementing the underlying iterator.
+  ///
+  /// Original note (describes the underlying iterator, whose code is not part of this class): the
+  /// first time this is called, predicates are updated and the internal pointer is reverted to the
+  /// first "steady state" tile; subsequent calls are lightweight and only update the pointer.
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator &operator++() {
+    ++iterator_;
+    return *this;
+  }
+
+  /// Post-increment form of the advance above.
+  ///
+  /// Copies *this, advances the original via operator++(), and returns the copy
+  /// (i.e. the iterator as it was before advancing).
+  ///
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator operator++(int) {
+    PredicatedTileIterator self(*this);
+    operator++();
+    return self;
+  }
+
+  /// Clears the predicate set efficiently (forwards `enable` to the underlying iterator)
+  CUTLASS_HOST_DEVICE
+  void clear_mask(bool enable = true) {
+    iterator_.clear_mask(enable);
+  }
+
+  /// Forwards to the underlying iterator's enable_mask()
+  CUTLASS_HOST_DEVICE
+  void enable_mask() {
+    iterator_.enable_mask();
+  }
+
+  /// Sets the predicate mask, overriding value stored in predicate iterator
+  CUTLASS_HOST_DEVICE
+  void set_mask(Mask const &mask) {
+    iterator_.set_mask(mask);
+  }
+
+  /// Gets the mask (written into the caller-supplied `mask`)
+  CUTLASS_HOST_DEVICE
+  void get_mask(Mask &mask) {
+    iterator_.get_mask(mask);
+  }
+
+  /// Loads a fragment from memory, with an extra offset given in units of Element
+  CUTLASS_DEVICE
+  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
+    iterator_.load_with_pointer_offset(frag, pointer_offset);
+  }
+
+  /// Loads a fragment from memory, with an extra offset given in bytes
+  CUTLASS_DEVICE
+  void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
+    iterator_.load_with_byte_offset(frag, byte_offset);
+  }
+
+  /// Loads a fragment from memory with a zero pointer offset
+  CUTLASS_DEVICE
+  void load(Fragment &frag) {
+    load_with_pointer_offset(frag, 0);
+  }
+
+  /// Store a fragment to memory, with an extra offset given in units of Element
+  CUTLASS_DEVICE
+  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
+    iterator_.store_with_pointer_offset(frag, pointer_offset);
+  }
+
+  /// Store a fragment to memory, with an extra offset given in bytes
+  CUTLASS_DEVICE
+  void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
+    iterator_.store_with_byte_offset(frag, byte_offset);
+  }
+
+  /// Store a fragment to memory with a zero pointer offset
+  CUTLASS_DEVICE
+  void store(Fragment const &frag) {
+    store_with_pointer_offset(frag, 0);
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Specialization of PredicatedTileIterator for row-major data.
+/// Adapter over a pitch-linear PredicatedTileIterator (`UnderlyingIterator`): coordinates and advance rank are swapped, then every member forwards to `iterator_`.
+/// Satisfies: ForwardTileIteratorConcept |
+///            ReadableContiguousTileIteratorConcept |
+///            WriteableContiguousTileIteratorConcept |
+///            MaskedTileIteratorConcept
+///
+template <
+  typename Shape_,
+  typename Element_,
+  int AdvanceRank,
+  typename ThreadMap_,
+  int AccessSize,
+  bool Gather,
+  typename PermuteLayout
+>
+class PredicatedTileIterator {  // NOTE(review): documented as a specialization, yet no argument list follows the class name in this text - looks stripped; confirm against upstream
+public:
+
+  static_assert(AdvanceRank == 0 || AdvanceRank == 1,
+    "Specialization for pitch-linear iterator may along advance along the "
+    "contiguous(rank=0) or strided(rank=1) dimension.");
+
+  using Shape = Shape_;
+  using Element = Element_;
+  using Layout = layout::RowMajor;
+  static int const kAdvanceRank = AdvanceRank;
+  using ThreadMap = ThreadMap_;
+
+  using Index = typename Layout::Index;
+  using LongIndex = typename Layout::LongIndex;
+
+  using TensorRef = TensorRef;
+  using TensorView = TensorView;
+  using TensorCoord = typename Layout::TensorCoord;
+
+  using Pointer = Element *;
+  using NonConstPointer = typename platform::remove_const::type *;
+
+  using UnderlyingIterator = PredicatedTileIterator<
+    layout::PitchLinearShape,
+    Element,
+    layout::PitchLinear,
+    (kAdvanceRank == 0 ? 1 : 0),  // advance rank is swapped (0 <-> 1) for the underlying iterator
+    ThreadMap,
+    AccessSize,
+    Gather,
+    PermuteLayout
+  >;
+
+  using AccessType = typename UnderlyingIterator::AccessType;
+
+  /// Fragment object to be loaded or stored
+  using Fragment = cutlass::Array;
+
+  /// Predicate vector stores mask to guard accesses (same type as the underlying iterator's Mask)
+  using Mask = typename UnderlyingIterator::Mask;
+
+  /// Parameters object is precomputed state and is host-constructible; it only wraps the underlying iterator's Params
+  class Params {
+  private:
+
+    friend PredicatedTileIterator;
+
+    /// Parameters object of the underlying iterator
+    typename UnderlyingIterator::Params params_;
+
+  public:
+
+    /// Default constructor
+    Params() = default;
+
+    /// Construct the Params object from a row-major layout; only layout.stride(0) is used to build the pitch-linear params
+    CUTLASS_HOST_DEVICE
+    Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {}
+
+    CUTLASS_HOST_DEVICE
+    Params(typename UnderlyingIterator::Params::Base const &base)  // initializes params_ from an underlying Params::Base
+        : params_(base) {}
+
+  };
+
+private:
+
+  //
+  // Data members
+  //
+
+  /// Underlying pitch-linear tile iterator; all work is delegated to it
+  UnderlyingIterator iterator_;
+
+public:
+
+  /// Default constructor
+  PredicatedTileIterator() = default;
+
+  /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID; (row, column) is forwarded as PitchLinearCoord(column, row)
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator(
+    Params const ¶ms,                         ///< Precomputed parameters object (NOTE(review): "¶ms" looks like a mangled "&params" - the body reads params.params_; confirm)
+    Pointer pointer,                             ///< Pointer to start of tensor
+    TensorCoord extent,                          ///< Extent of tensor
+    int thread_id,                               ///< ID of each participating thread
+    TensorCoord const &threadblock_offset,       ///< Initial offset of threadblock
+    int const *indices = nullptr                 ///< Gather indices (forwarded unchanged to the underlying iterator)
+  ):
+    iterator_(
+      params.params_,
+      pointer,
+      layout::PitchLinearCoord(extent.column(), extent.row()),
+      thread_id,
+      layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()),
+      indices
+    ) { }
+
+  /// Construct a PredicatedTileIterator with zero threadblock offset (delegates to the constructor above with make_Coord(0, 0))
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator(
+    Params const ¶ms,       ///< Precomputed parameters object
+    Pointer pointer,           ///< Pointer to start of tensor
+    TensorCoord extent,        ///< Extent of tensor
+    int thread_id              ///< ID of each participating thread
+  ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
+
+  /// Adds a pointer offset in units of Element (forwarded to the underlying iterator)
+  CUTLASS_HOST_DEVICE
+  void add_pointer_offset(LongIndex pointer_offset) {
+    iterator_.add_pointer_offset(pointer_offset);
+  }
+
+  /// Advances to the next tile in memory by pre-incrementing the underlying iterator.
+  ///
+  /// Original note (describes the underlying iterator, whose code is not part of this class): the
+  /// first time this is called, predicates are updated and the internal pointer is reverted to the
+  /// first "steady state" tile; subsequent calls are lightweight and only update the pointer.
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator &operator++() {
+    ++iterator_;
+    return *this;
+  }
+
+  /// Post-increment form of the advance above.
+  ///
+  /// Copies *this, advances the original via operator++(), and returns the copy
+  /// (i.e. the iterator as it was before advancing).
+  ///
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator operator++(int) {
+    PredicatedTileIterator self(*this);
+    operator++();
+    return self;
+  }
+
+  /// Clears the predicate set efficiently (forwards `enable` to the underlying iterator)
+  CUTLASS_HOST_DEVICE
+  void clear_mask(bool enable = true) {
+    iterator_.clear_mask(enable);
+  }
+
+  /// Forwards to the underlying iterator's enable_mask()
+  CUTLASS_HOST_DEVICE
+  void enable_mask() {
+    iterator_.enable_mask();
+  }
+
+  /// Sets the predicate mask, overriding value stored in predicate iterator
+  CUTLASS_HOST_DEVICE
+  void set_mask(Mask const &mask) {
+    iterator_.set_mask(mask);
+  }
+
+  /// Gets the mask (written into the caller-supplied `mask`)
+  CUTLASS_HOST_DEVICE
+  void get_mask(Mask &mask) {
+    iterator_.get_mask(mask);
+  }
+
+  /// Loads a fragment from memory, with an extra offset given in units of Element
+  CUTLASS_DEVICE
+  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
+    iterator_.load_with_pointer_offset(frag, pointer_offset);
+  }
+
+  /// Loads a fragment from memory, with an extra offset given in bytes
+  CUTLASS_DEVICE
+  void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
+    iterator_.load_with_byte_offset(frag, byte_offset);
+  }
+
+  /// Loads a fragment from memory with a zero pointer offset
+  CUTLASS_DEVICE
+  void load(Fragment &frag) {
+    load_with_pointer_offset(frag, 0);
+  }
+
+  /// Store a fragment to memory, with an extra offset given in units of Element
+  CUTLASS_DEVICE
+  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
+    iterator_.store_with_pointer_offset(frag, pointer_offset);
+  }
+
+  /// Store a fragment to memory, with an extra offset given in bytes
+  CUTLASS_DEVICE
+  void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
+    iterator_.store_with_byte_offset(frag, byte_offset);
+  }
+
+  /// Store a fragment to memory with a zero pointer offset
+  CUTLASS_DEVICE
+  void store(Fragment const &frag) {
+    store_with_pointer_offset(frag, 0);
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Specialization of PredicatedTileIterator for affine rank-2 data.
+/// Drives a PredicatedTileAccessIterator (`address_iterator_`) directly and implements the fragment load/store loops itself.
+/// Satisfies: ForwardTileIteratorConcept |
+///            ReadableContiguousTileIteratorConcept |
+///            WriteableContiguousTileIteratorConcept |
+///            MaskedTileIteratorConcept
+///
+template  // NOTE(review): the template parameter list and the start of the specialization argument list are absent from this text - looks stripped; confirm against upstream
+class PredicatedTileIterator, AdvanceRank,
+                             ThreadMap_, AccessSize, false> {
+ public:
+  static_assert(
+      AdvanceRank == 0 || AdvanceRank == 1,
+      "Specialization for pitch-linear iterator may advance along the "
+      "contiguous(rank=0) or strided(rank=1) dimension.");
+
+  using Shape = Shape_;
+  using Element = Element_;
+  using Layout = layout::AffineRankN<2>;
+  static int const kAdvanceRank = AdvanceRank;
+  using ThreadMap = ThreadMap_;
+
+  using Index = typename Layout::Index;
+  using LongIndex = typename Layout::LongIndex;
+
+  using TensorRef = TensorRef;
+  using TensorView = TensorView;
+  using TensorCoord = typename Layout::TensorCoord;
+
+  using Pointer = Element *;
+  using NonConstPointer = typename platform::remove_const::type *;
+
+  /// Type used for internal memory accesses; fragments are reinterpreted as arrays of this type in load/store
+  using AccessType = AlignedArray::value / 8)>;  // NOTE(review): argument list looks truncated in this text - confirm
+
+  /// Underlying iterator to compute the addresses
+  using TileAccessIterator =
+      PredicatedTileAccessIterator;
+
+  static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;  // trip count of the innermost load/store loop
+
+  /// Fragment object to be loaded or stored
+  using Fragment = cutlass::Array;
+
+  /// Predicate vector stores mask to guard accesses (same type as the access iterator's Mask)
+  using Mask = typename TileAccessIterator::Mask;
+
+  /// Parameters object is precomputed state and is host-constructible; it only wraps the access iterator's Params
+  class Params {
+   public:
+
+    friend PredicatedTileIterator;
+
+   private:
+    /// Parameters object of the access iterator
+    typename TileAccessIterator::Params params_;
+
+   public:
+    /// Construct the Params object from an AffineRankN<2> layout, passed straight to the access iterator's Params
+    CUTLASS_HOST_DEVICE
+    Params(Layout const &layout) : params_(layout) {}
+
+    /// Default constructor
+    Params() = default;
+  };
+
+ private:
+  /// Internal pointer type permits fast address arithmetic (not referenced elsewhere in this class)
+  using BytePointer = char *;
+
+ private:
+  //
+  // Data members
+  //
+
+  /// Data member to the tile access iterator; it supplies addresses and validity for every access
+  TileAccessIterator address_iterator_;
+
+ public:
+
+  /// Default constructor
+  PredicatedTileIterator() = default;
+
+  /// Constructs a TileIterator from its precomputed state, threadblock offset,
+  /// and thread ID (`indices` is accepted but not passed to the access iterator)
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator(
+      /// Precomputed parameters object (NOTE(review): "¶ms" below looks like a mangled "&params" - the body reads params.params_; confirm)
+      Params const ¶ms,
+      /// Pointer to start of tensor
+      Pointer pointer,
+      /// Extent of tensor
+      TensorCoord extent,
+      /// ID of each participating thread
+      int thread_id,
+      /// Initial offset of threadblock
+      TensorCoord const &threadblock_offset,
+      int const *indices = nullptr     ///< gather/scatter indices, note no support for gather/scatter at this specialization (the argument is ignored)
+      )
+      : address_iterator_(params.params_, pointer, extent, thread_id,
+                          threadblock_offset) {}
+
+  /// Construct a PredicatedTileIterator with zero threadblock offset (delegates to the constructor above with make_Coord(0, 0))
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator(
+      Params const ¶ms,  ///< Precomputed parameters object
+      Pointer pointer,      ///< Pointer to start of tensor
+      TensorCoord extent,   ///< Extent of tensor
+      int thread_id         ///< ID of each participating thread
+      )
+      : PredicatedTileIterator(params, pointer, extent, thread_id,
+                               make_Coord(0, 0)) {}
+
+  /// Adds a pointer offset in units of Element (forwarded to the access iterator)
+  CUTLASS_HOST_DEVICE
+  void add_pointer_offset(LongIndex pointer_offset) {
+    address_iterator_.add_pointer_offset(pointer_offset);
+  }
+
+  /// Advances to the next tile in memory.
+  ///
+  /// Moves the access iterator by exactly one tile: tile offset (0, 1) when
+  /// kAdvanceRank is non-zero, otherwise (1, 0).
+  /// NOTE(review): any predicate update happens inside add_tile_offset(),
+  /// whose code is not part of this class - confirm there.
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator &operator++() {
+    if (kAdvanceRank)
+      address_iterator_.add_tile_offset(make_Coord(0, 1));
+    else
+      address_iterator_.add_tile_offset(make_Coord(1, 0));
+
+    return *this;
+  }
+
+  /// Post-increment form of the advance above.
+  ///
+  /// Copies *this, advances the original via operator++(), and returns the
+  /// copy (i.e. the iterator as it was before advancing).
+  ///
+  ///
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator operator++(int) {
+    PredicatedTileIterator self(*this);
+    operator++();
+    return self;
+  }
+
+  /// Clears the predicate set efficiently (forwards `enable` to the access iterator)
+  CUTLASS_HOST_DEVICE
+  void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
+
+  /// Forwards to the access iterator's enable_mask()
+  CUTLASS_HOST_DEVICE
+  void enable_mask() { address_iterator_.enable_mask(); }
+
+  /// Sets the predicate mask, overriding value stored in predicate iterator
+  CUTLASS_HOST_DEVICE
+  void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); }
+
+  /// Gets the mask (written into the caller-supplied `mask`)
+  CUTLASS_HOST_DEVICE
+  void get_mask(Mask &mask) { address_iterator_.get_mask(mask); }
+
+  CUTLASS_DEVICE
+  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {  // loads with an offset given in units of Element
+    load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8);  // element offset -> byte offset (bits / 8)
+  }
+
+  CUTLASS_DEVICE
+  void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {  // loads every access of the tile into `frag`, offsetting each address by `byte_offset` bytes
+
+    AccessType *frag_ptr = reinterpret_cast(&frag);  // view the fragment as an array of AccessType slots
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int v = 0; v < kAccessesPerVector; ++v) {
+
+          int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);  // linear slot index: v varies fastest, then c, then s
+
+          address_iterator_.set_iteration_index(idx);  // explicitly seat the access iterator on this slot before reading its pointer
+          char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset;
+
+          AccessType const *access_ptr = reinterpret_cast(byte_ptr);
+
+          cutlass::arch::global_load(
+              frag_ptr[idx], access_ptr, address_iterator_.valid());  // valid() is passed as the third argument (presumably the load guard - confirm against global_load)
+
+          ++address_iterator_;
+        }
+      }
+    }
+  }
+
+  /// Loads a fragment from memory with a zero byte offset
+  CUTLASS_DEVICE
+  void load(Fragment &frag) { load_with_byte_offset(frag, 0); }
+
+  /// Store a fragment to memory, with an extra offset given in units of Element
+  CUTLASS_DEVICE
+  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
+    store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8);  // element offset -> byte offset (bits / 8)
+  }
+
+  /// Store a fragment to memory, offsetting each address by `byte_offset` bytes; invalid accesses are skipped
+  CUTLASS_DEVICE
+  void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
+    address_iterator_.set_iteration_index(0);  // seeded once here; the loop below then relies on ++address_iterator_ to step through the accesses
+    AccessType const *frag_ptr = reinterpret_cast(&frag);  // view the fragment as an array of AccessType slots
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
+        CUTLASS_PRAGMA_UNROLL
+        for (int v = 0; v < kAccessesPerVector; ++v) {
+
+          int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);  // linear slot index: v varies fastest, then c, then s
+
+          char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset;
+          AccessType *access_ptr = reinterpret_cast(byte_ptr);
+
+          if (address_iterator_.valid()) {  // write only when the current access is reported valid
+            *access_ptr = frag_ptr[idx];
+          }
+          ++address_iterator_;
+        }
+      }
+    }
+  }
+
+  /// Store a fragment to memory with a zero byte offset
+  CUTLASS_DEVICE
+  void store(Fragment const &frag) { store_with_byte_offset(frag, 0); }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Specialization of PredicatedTileIterator for affine rank 2 column-major data.
+/// Adapter over the AffineRankN<2> PredicatedTileIterator (`UnderlyingIterator`): every member forwards to `iterator_`.
+/// Satisfies: ForwardTileIteratorConcept |
+///            ReadableContiguousTileIteratorConcept |
+///            WriteableContiguousTileIteratorConcept |
+///            MaskedTileIteratorConcept
+///
+template <
+  typename Shape_,
+  typename Element_,
+  int AdvanceRank,
+  typename ThreadMap_,
+  int AccessSize
+>
+class PredicatedTileIterator {  // NOTE(review): documented as a specialization, yet no argument list follows the class name in this text - looks stripped; confirm against upstream
+public:
+
+  static_assert(AdvanceRank == 0 || AdvanceRank == 1,
+    "Specialization for pitch-linear iterator may along advance along the "
+    "contiguous(rank=0) or strided(rank=1) dimension.");
+
+  using Shape = Shape_;
+  using Element = Element_;
+  using Layout = layout::AffineRank2ColumnMajor;
+  static int const kAdvanceRank = AdvanceRank;
+  using ThreadMap = ThreadMap_;
+
+  using Index = typename Layout::Index;
+  using LongIndex = typename Layout::LongIndex;
+
+  using TensorRef = TensorRef;
+  using TensorView = TensorView;
+  using TensorCoord = typename Layout::TensorCoord;
+
+  using Pointer = Element *;
+  using NonConstPointer = typename platform::remove_const::type *;
+
+  // Map to the underlying AffineRankN<2> layout
+  using UnderlyingIterator = PredicatedTileIterator<
+    layout::PitchLinearShape,
+    Element,
+    layout::AffineRankN<2>,
+    (kAdvanceRank == 0 ? 0 : 1),  // advance rank is passed through unchanged
+    ThreadMap,
+    AccessSize
+  >;
+
+  using AccessType = typename UnderlyingIterator::AccessType;
+
+  /// Fragment object to be loaded or stored
+  using Fragment = cutlass::Array;
+
+  /// Predicate vector stores mask to guard accesses (same type as the underlying iterator's Mask)
+  using Mask = typename UnderlyingIterator::Mask;
+
+  /// Parameters object is precomputed state and is host-constructible; it only wraps the underlying iterator's Params
+  class Params {
+  private:
+
+    friend PredicatedTileIterator;
+
+    /// Parameters object of the underlying iterator
+    typename UnderlyingIterator::Params params_;
+
+  public:
+
+    /// Default constructor
+    Params() = default;
+
+    /// Construct the Params object given an AffineRankN<2> tensor's layout; strides are passed in the order (stride(0), stride(1))
+    CUTLASS_HOST_DEVICE
+    Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1)))
+    {}
+  };
+
+private:
+
+  //
+  // Data members
+  //
+
+  /// Underlying AffineRankN<2> tile iterator; all work is delegated to it
+  UnderlyingIterator iterator_;
+
+public:
+
+  /// Default constructor
+  PredicatedTileIterator() = default;
+
+  /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID; (row, column) is forwarded as PitchLinearCoord(row, column)
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator(
+    Params const ¶ms,                         ///< Precomputed parameters object (NOTE(review): "¶ms" looks like a mangled "&params" - the body reads params.params_; confirm)
+    Pointer pointer,                             ///< Pointer to start of tensor
+    TensorCoord extent,                          ///< Extent of tensor
+    int thread_id,                               ///< ID of each participating thread
+    TensorCoord const &threadblock_offset,       ///< Initial offset of threadblock
+    int const *indices = nullptr                 ///< gather/scatter indices, note no support for gather/scatter at this specialization (the argument is not forwarded)
+  ):
+    iterator_(
+      params.params_,
+      pointer,
+      layout::PitchLinearCoord(extent.row(), extent.column()),
+      thread_id,
+      layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column())
+    ) { }
+
+  /// Construct a PredicatedTileIterator with zero threadblock offset (delegates to the constructor above with make_Coord(0, 0))
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator(
+    Params const ¶ms,       ///< Precomputed parameters object
+    Pointer pointer,           ///< Pointer to start of tensor
+    TensorCoord extent,        ///< Extent of tensor
+    int thread_id              ///< ID of each participating thread
+  ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
+
+  /// Adds a pointer offset in units of Element (forwarded to the underlying iterator)
+  CUTLASS_HOST_DEVICE
+  void add_pointer_offset(LongIndex pointer_offset) {
+    iterator_.add_pointer_offset(pointer_offset);
+  }
+
+  /// Advances to the next tile in memory by pre-incrementing the underlying iterator.
+  ///
+  /// Original note (describes the underlying iterator rather than this wrapper): the first
+  /// time this is called, predicates are updated and the internal pointer is reverted to the
+  /// first "steady state" tile; subsequent calls are lightweight and only update the pointer.
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator &operator++() {
+    ++iterator_;
+    return *this;
+  }
+
+  /// Post-increment form of the advance above.
+  ///
+  /// Copies *this, advances the original via operator++(), and returns the copy
+  /// (i.e. the iterator as it was before advancing).
+  ///
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator operator++(int) {
+    PredicatedTileIterator self(*this);
+    operator++();
+    return self;
+  }
+
+  /// Clears the predicate set efficiently (forwards `enable` to the underlying iterator)
+  CUTLASS_HOST_DEVICE
+  void clear_mask(bool enable = true) {
+    iterator_.clear_mask(enable);
+  }
+
+  /// Forwards to the underlying iterator's enable_mask()
+  CUTLASS_HOST_DEVICE
+  void enable_mask() {
+    iterator_.enable_mask();
+  }
+
+  /// Sets the predicate mask, overriding value stored in predicate iterator
+  CUTLASS_HOST_DEVICE
+  void set_mask(Mask const &mask) {
+    iterator_.set_mask(mask);
+  }
+
+  /// Gets the mask (written into the caller-supplied `mask`)
+  CUTLASS_HOST_DEVICE
+  void get_mask(Mask &mask) {
+    iterator_.get_mask(mask);
+  }
+
+  /// Loads a fragment from memory, with an extra offset given in units of Element
+  CUTLASS_DEVICE
+  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
+    iterator_.load_with_pointer_offset(frag, pointer_offset);
+  }
+
+  /// Loads a fragment from memory, with an extra offset given in bytes
+  CUTLASS_DEVICE
+  void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
+    iterator_.load_with_byte_offset(frag, byte_offset);
+  }
+
+  /// Loads a fragment from memory with a zero pointer offset
+  CUTLASS_DEVICE
+  void load(Fragment &frag) {
+    load_with_pointer_offset(frag, 0);
+  }
+
+  /// Store a fragment to memory, with an extra offset given in units of Element
+  CUTLASS_DEVICE
+  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
+    iterator_.store_with_pointer_offset(frag, pointer_offset);
+  }
+
+  /// Store a fragment to memory, with an extra offset given in bytes
+  CUTLASS_DEVICE
+  void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
+    iterator_.store_with_byte_offset(frag, byte_offset);
+  }
+
+  /// Store a fragment to memory with a zero pointer offset
+  CUTLASS_DEVICE
+  void store(Fragment const &frag) {
+    store_with_pointer_offset(frag, 0);
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Specialization of PredicatedTileIterator for affine rank 2 row-major data.
+/// Adapter over the AffineRankN<2> PredicatedTileIterator (`UnderlyingIterator`): coordinates, strides and advance rank are swapped, then every member forwards to `iterator_`.
+/// Satisfies: ForwardTileIteratorConcept |
+///            ReadableContiguousTileIteratorConcept |
+///            WriteableContiguousTileIteratorConcept |
+///            MaskedTileIteratorConcept
+///
+template <
+  typename Shape_,
+  typename Element_,
+  int AdvanceRank,
+  typename ThreadMap_,
+  int AccessSize
+>
+class PredicatedTileIterator {  // NOTE(review): documented as a specialization, yet no argument list follows the class name in this text - looks stripped; confirm against upstream
+public:
+
+  static_assert(AdvanceRank == 0 || AdvanceRank == 1,
+    "Specialization for pitch-linear iterator may along advance along the "
+    "contiguous(rank=0) or strided(rank=1) dimension.");
+
+  using Shape = Shape_;
+  using Element = Element_;
+  using Layout = layout::AffineRank2RowMajor;
+  static int const kAdvanceRank = AdvanceRank;
+  using ThreadMap = ThreadMap_;
+
+  using Index = typename Layout::Index;
+  using LongIndex = typename Layout::LongIndex;
+
+  using TensorRef = TensorRef;
+  using TensorView = TensorView;
+  using TensorCoord = typename Layout::TensorCoord;
+
+  using Pointer = Element *;
+  using NonConstPointer = typename platform::remove_const::type *;
+
+  // Map to the underlying AffineRankN<2> layout
+  using UnderlyingIterator = PredicatedTileIterator<
+    layout::PitchLinearShape,
+    Element,
+    layout::AffineRankN<2>,
+    (kAdvanceRank == 0 ? 1 : 0),  // advance rank is swapped (0 <-> 1) for the underlying iterator
+    ThreadMap,
+    AccessSize
+  >;
+
+  using AccessType = typename UnderlyingIterator::AccessType;
+
+  /// Fragment object to be loaded or stored
+  using Fragment = cutlass::Array;
+
+  /// Predicate vector stores mask to guard accesses (same type as the underlying iterator's Mask)
+  using Mask = typename UnderlyingIterator::Mask;
+
+  /// Parameters object is precomputed state and is host-constructible; it only wraps the underlying iterator's Params
+  class Params {
+  private:
+
+    friend PredicatedTileIterator;
+
+    /// Parameters object of the underlying iterator
+    typename UnderlyingIterator::Params params_;
+
+  public:
+
+    /// Default constructor
+    Params() = default;
+
+    /// Construct the Params object given an AffineRankN<2> tensor's layout; strides are passed in swapped order (stride(1), stride(0))
+    CUTLASS_HOST_DEVICE
+    Params(Layout const &layout): params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {}
+  };
+
+
+private:
+
+  //
+  // Data members
+  //
+
+  /// Underlying AffineRankN<2> tile iterator; all work is delegated to it
+  UnderlyingIterator iterator_;
+
+public:
+
+  /// Default constructor
+  PredicatedTileIterator() = default;
+
+  /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID; (row, column) is forwarded as PitchLinearCoord(column, row)
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator(
+    Params const ¶ms,                         ///< Precomputed parameters object (NOTE(review): "¶ms" looks like a mangled "&params" - the body reads params.params_; confirm)
+    Pointer pointer,                             ///< Pointer to start of tensor
+    TensorCoord extent,                          ///< Extent of tensor
+    int thread_id,                               ///< ID of each participating thread
+    TensorCoord const &threadblock_offset,       ///< Initial offset of threadblock
+    int const *indices = nullptr                 ///< gather/scatter indices, note no support for gather/scatter at this specialization (the argument is not forwarded)
+  ):
+    iterator_(
+      params.params_,
+      pointer,
+      layout::PitchLinearCoord(extent.column(), extent.row()),
+      thread_id,
+      layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row())
+    ) { }
+
+  /// Construct a PredicatedTileIterator with zero threadblock offset (delegates to the constructor above with make_Coord(0, 0))
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator(
+    Params const ¶ms,       ///< Precomputed parameters object
+    Pointer pointer,           ///< Pointer to start of tensor
+    TensorCoord extent,        ///< Extent of tensor
+    int thread_id              ///< ID of each participating thread
+  ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
+
+  /// Adds a pointer offset in units of Element (forwarded to the underlying iterator)
+  CUTLASS_HOST_DEVICE
+  void add_pointer_offset(LongIndex pointer_offset) {
+    iterator_.add_pointer_offset(pointer_offset);
+  }
+
+  /// Advances to the next tile in memory by pre-incrementing the underlying iterator.
+  ///
+  /// Original note (describes the underlying iterator rather than this wrapper): the first
+  /// time this is called, predicates are updated and the internal pointer is reverted to the
+  /// first "steady state" tile; subsequent calls are lightweight and only update the pointer.
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator &operator++() {
+    ++iterator_;
+    return *this;
+  }
+
+  /// Post-increment form of the advance above.
+  ///
+  /// Copies *this, advances the original via operator++(), and returns the copy
+  /// (i.e. the iterator as it was before advancing).
+  ///
+  CUTLASS_HOST_DEVICE
+  PredicatedTileIterator operator++(int) {
+    PredicatedTileIterator self(*this);
+    operator++();
+    return self;
+  }
+
+  /// Clears the predicate set efficiently (forwards `enable` to the underlying iterator)
+  CUTLASS_HOST_DEVICE
+  void clear_mask(bool enable = true) {
+    iterator_.clear_mask(enable);
+  }
+
+  /// Forwards to the underlying iterator's enable_mask()
+  CUTLASS_HOST_DEVICE
+  void enable_mask() {
+    iterator_.enable_mask();
+  }
+
+  /// Sets the predicate mask, overriding value stored in predicate iterator
+  CUTLASS_HOST_DEVICE
+  void set_mask(Mask const &mask) {
+    iterator_.set_mask(mask);
+  }
+
+  /// Gets the mask (written into the caller-supplied `mask`)
+  CUTLASS_HOST_DEVICE
+  void get_mask(Mask &mask) {
+    iterator_.get_mask(mask);
+  }
+
+  /// Loads a fragment from memory, with an extra offset given in units of Element
+  CUTLASS_DEVICE
+  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
+    iterator_.load_with_pointer_offset(frag, pointer_offset);
+  }
+
+  /// Loads a fragment from memory, with an extra offset given in bytes
+  CUTLASS_DEVICE
+  void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) {
+    iterator_.load_with_byte_offset(frag, byte_offset);
+  }
+
+  /// Loads a fragment from memory with a zero pointer offset
+  CUTLASS_DEVICE
+  void load(Fragment &frag) {
+    load_with_pointer_offset(frag, 0);
+  }
+
+  /// Store a fragment to memory, with an extra offset given in units of Element
+  CUTLASS_DEVICE
+  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
+    iterator_.store_with_pointer_offset(frag, pointer_offset);
+  }
+
+  /// Store a fragment to memory, with an extra offset given in bytes
+  CUTLASS_DEVICE
+  void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) {
+    iterator_.store_with_byte_offset(frag, byte_offset);
+  }
+
+  /// Store a fragment to memory with a zero pointer offset
+  CUTLASS_DEVICE
+  void store(Fragment const &frag) {
+    store_with_pointer_offset(frag, 0);
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+/// Specialization of PredicatedTileIterator for interleaved data. It is mapped
+/// to the congruous layout.
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template +class PredicatedTileIterator, + AdvanceRank, ThreadMap_, AccessSize, false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileIterator< + layout::PitchLinearShape, + Element, layout::PitchLinear, (kAdvanceRank == 0 ? 
0 : 1), ThreadMap, AccessSize>; + + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + + /// Default constructor + PredicatedTileIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + 
Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIterator operator++(int) { + PredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator for interleaved-32 data. It is +/// mapped to the congruous layout. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIterator, + AdvanceRank, ThreadMap_, AccessSize, false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileIterator< + layout::PitchLinearShape, + Element, layout::PitchLinear, (kAdvanceRank == 0 ? 
1 : 0), ThreadMap, AccessSize>; + + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIterator; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default constructor + Params() = default; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + + /// Default constructor + PredicatedTileIterator() = default; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ) + : iterator_(params.params_, pointer, + layout::PitchLinearCoord(extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator( + Params 
const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIterator(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIterator operator++(int) { + PredicatedTileIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h new file mode 100644 index 0000000000000000000000000000000000000000..cbe48df6e7dc1c66c9e55b8eab14aa1fb53bc14b --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h @@ -0,0 +1,787 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile + first, with the objective of minimizing predicate mask updates during steady-state operation. + + A precomputed "Params" object minimizes the amount of state that must be stored in registers, + and integer addition is used to advance the pointer through memory. +*/ + +#pragma once + +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h" +#include "cutlass/transform/thread/transpose.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileIterator2dThreadTile +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize register liveness +/// and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. +/// +/// Base pointer and tensor extents may be specified at the time the iterator is constructed. +/// Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. +/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. +/// +/// Vistitation order is intended to first visit a "residual" tile that may be partially full in +/// both the advance dimension and the steady-state dimension. This is assumed to be the last +/// tile in the iteration sequence. 
Advancing an iterator that has just been constructed moves to +/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent +/// accesses may be performed without updating internal predicates and are efficient in terms of +/// live register state and pointer arithmetic instructions. +/// +/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once +/// outside any looping structure to minimize integer arithmetic. +/// +/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing +/// the iterator. +/// +/// +/// Example: +/// +/// An efficient pipeline structure may be constructed as follows: +/// +// template +// __global__ void kernel( +// typename Iterator::Params params, +// typename Iterator::Element *ptr, +// TensorCoord extent) { +// +// typename Iterator::Fragment fragment; +// +// TensorCoord threadblock_offset(0, 0); +// +// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); +// +// +// fragment = *iter; // load "residue" tile first +// ++iter; // advance to first "steady state" tile and update internal masks +// +// +// #pragma unroll +// for (int i = Remaining - 1; i >= 0; --i) { +// +// f(fragment); +// +// if (!i) { +// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. 
+// } +// +// fragment = *iter; // load tile during "steady state" phase +// ++iter; // advance to next tile - lightweight due to steady-state masks +// } +// } +// +// void host(TensorView view) { +// +// using Iterator = transform::threadblock::PredicatedTileIterator2dThreadTile; +// +// typename Iterator::Params params(view.layout()); +// +// kernel(params, view.data()); +// } +/// +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + bool Transpose = false +> +class PredicatedTileIterator2dThreadTile; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIterator2dThreadTile { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + /// Type used for internal memory accesses + /// extra set of parenthesis is needed for VS compiler + struct alignas((ThreadMap::kElementsPerAccess * sizeof_bits::value / + 8)) AccessType { + + Array storage; + + static int const kElements = ThreadMap::kElementsPerAccess; + }; + + /// Optionally this fragment can be 4x4 transposed + using Transform 
= thread::Transpose< ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount , layout::PitchLinearShape<4,4>, Element>; + static bool const transpose = Transpose_; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = + PredicatedTileAccessIterator2dThreadTile; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + using Base = typename TileAccessIterator::Params::Base; + + friend PredicatedTileIterator2dThreadTile; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : params_(layout) { } + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Base const &base) + : params_(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset, + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ) + : address_iterator_(params.params_, pointer, extent, thread_id, + threadblock_offset) {} + + /// Construct a 
PredicatedTileIterator2dThreadTile with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile &operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset({0, 1}); + else + address_iterator_.add_tile_offset({1, 0}); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile operator++(int) { + PredicatedTileIterator2dThreadTile self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { address_iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++){ + + int access_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + \ + s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; + + address_iterator_.set_iteration_index(access_idx); + if (address_iterator_.valid()) { + + frag_ptr[access_idx] = + *(address_iterator_.get() + pointer_offset); + } + + ++address_iterator_; + } + } + } + + if (transpose) { + Transform t; + t.transform(frag, frag); + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + 
CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ts++){ + + int access_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided + \ + s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided; + + address_iterator_.set_iteration_index(access_idx); + if (address_iterator_.valid()) { + *(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + bool Transpose_ +> +class PredicatedTileIterator2dThreadTile { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + static bool const Transpose = Transpose_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileIterator2dThreadTile< + layout::PitchLinearShape, + Element, + 
layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + Transpose + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend PredicatedTileIterator2dThreadTile; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + +private: + + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile( + Params const ¶ms, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset, ///< Initial offset of threadblock + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) + ) { } + + /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile( + Params const ¶ms, ///< Precomputed parameters object + 
Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile operator++(int) { + PredicatedTileIterator2dThreadTile self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + bool Transpose_ +> +class PredicatedTileIterator2dThreadTile { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + static bool const Transpose = Transpose_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileIterator2dThreadTile< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + Transpose + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend PredicatedTileIterator2dThreadTile; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default-construct the Params object + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a row-major tensor's layout; the layout's + /// stride(0) is used to build the pitch-linear layout of the underlying iterator + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { } + + /// Construct the Params object from the underlying iterator's base parameters + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const &base) + : params_(base) {} + }; + + +private: + + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + /// + /// The row-major extent and threadblock offset are forwarded to the underlying iterator + /// as PitchLinearCoord(column, row). The `indices` argument is accepted but not used. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile( + Params const &params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset, ///< Initial offset of threadblock + int const *indices = nullptr ///< gather/scatter indices, note no support for gather/scatter at this specialization + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) + ) { } + + /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile( + Params const &params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor 
+ TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIterator2dThreadTile operator++(int) { + PredicatedTileIterator2dThreadTile self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Enables the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory, forwarding pointer_offset to the underlying iterator + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory, forwarding pointer_offset to the underlying iterator + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h new file mode 100644 index 
0000000000000000000000000000000000000000..9bf5e8586675c11bb52e2db5346ff19f489461af --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator_triangular_matrix.h @@ -0,0 +1,818 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile + first, with the objective of minimizing predicate mask updates during steady-state operation. + + A precomputed "Params" object minimizes the amount of state that must be stored in registers, + and integer addition is used to advance the pointer through memory. +*/ + +#pragma once + +#include "cutlass/arch/memory.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_triangular_matrix.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileIteratorTriangularMatrix +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize register liveness +/// and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" object is constructed. 
+/// +/// Base pointer and tensor extents may be specified at the time the iterator is constructed. +/// Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator is constructed. +/// Subsequent additions to logical coordinate offset may be performed but are relatively expensive. +/// +/// Visitation order is intended to first visit a "residual" tile that may be partially full in +/// both the advance dimension and the steady-state dimension. This is assumed to be the last +/// tile in the iteration sequence. Advancing an iterator that has just been constructed moves to +/// the first tile that is full in the advance dimension and recomputes predicates. Subsequent +/// accesses may be performed without updating internal predicates and are efficient in terms of +/// live register state and pointer arithmetic instructions. +/// +/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once +/// outside any looping structure to minimize integer arithmetic. +/// +/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing +/// the iterator. +/// +/// +/// Example: +/// +/// An efficient pipeline structure may be constructed as follows: +/// +// template +// __global__ void kernel( +// typename Iterator::Params params, +// typename Iterator::Element *ptr, +// TensorCoord extent) { +// +// typename Iterator::Fragment fragment; +// +// TensorCoord threadblock_offset(0, 0); +// +// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offset); +// +// +// fragment = *iter; // load "residue" tile first +// ++iter; // advance to first "steady state" tile and update internal masks +// +// +// #pragma unroll +// for (int i = Remaining - 1; i >= 0; --i) { +// +// f(fragment); +// +// if (!i) { +// iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs. 
+// } +// +// fragment = *iter; // load tile during "steady state" phase +// ++iter; // advance to next tile - lightweight due to steady-state masks +// } +// } +// +// void host(TensorView view) { +// +// using Iterator = transform::threadblock::PredicatedTileIteratorTriangularMatrix; +// +// typename Iterator::Params params(view.layout()); +// +// kernel(params, view.data()); +// } +/// +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + SideMode kSideMode, + FillMode kFillMode, + DiagType kDiagType, + int AccessSize = ThreadMap::kElementsPerAccess +> +class PredicatedTileIteratorTriangularMatrix; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorTriangularMatrix for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template +class PredicatedTileIteratorTriangularMatrix { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + /// Type used for internal memory accesses + using AccessType = AlignedArray::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = + PredicatedTileAccessIteratorTriangularMatrix; + + static int const 
kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileIteratorTriangularMatrix; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const &layout) : params_(layout) { } + + /// Default-construct the Params object + CUTLASS_HOST_DEVICE + Params() { } + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix( + /// Precomputed parameters object + Params const &params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : address_iterator_(params.params_, pointer, extent, thread_id, + threadblock_offset) {} + + /// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix( + Params const &params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE 
+ void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix &operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset({0, 1}); + else + address_iterator_.add_tile_offset({1, 0}); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix operator++(int) { + PredicatedTileIteratorTriangularMatrix self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); } + + /// Enables the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { address_iterator_.enable_mask(); } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { address_iterator_.get_mask(mask); } + + /// Loads a fragment from memory; pointer_offset is converted to a byte offset + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Loads a fragment from memory, adding byte_offset to every access address; + /// each access is guarded by the address iterator's valid() predicate + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; 
s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + + AccessType const *access_ptr = reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_byte_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + + int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType *access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_byte_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of 
PredicatedTileIteratorTriangularMatrix for column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + SideMode kSideMode, + FillMode kFillMode, + DiagType kDiagType, + int AccessSize +> +class PredicatedTileIteratorTriangularMatrix { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileIteratorTriangularMatrix< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + kSideMode, + kFillMode, + kDiagType, + AccessSize + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend PredicatedTileIteratorTriangularMatrix; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default-construct the Params object + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a column-major tensor's layout; the layout's + /// stride(0) is used to build the pitch-linear layout of the underlying iterator + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { + + } + }; + + +private: + + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + /// + /// The column-major extent and threadblock offset are forwarded to the underlying + /// iterator as PitchLinearCoord(row, column). 
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix( + Params const &params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset ///< Initial offset of threadblock + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column()) + ) { } + + /// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix( + Params const &params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): PredicatedTileIteratorTriangularMatrix(params, pointer, extent, 
thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix operator++(int) { + PredicatedTileIteratorTriangularMatrix self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Enables the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a 
fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorTriangularMatrix for row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + SideMode kSideMode, + FillMode kFillMode, + DiagType kDiagType, + int AccessSize +> +class PredicatedTileIteratorTriangularMatrix { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedTileIteratorTriangularMatrix< + layout::PitchLinearShape, + Element, + 
layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + kSideMode, + kFillMode, + kDiagType, + AccessSize + >; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + + friend PredicatedTileIteratorTriangularMatrix; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + + /// Default-construct the Params object + CUTLASS_HOST_DEVICE + Params() { } + + /// Construct the Params object given a row-major tensor's layout; the layout's + /// stride(0) is used to build the pitch-linear layout of the underlying iterator + CUTLASS_HOST_DEVICE + Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { + + } + }; + + +private: + + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + +public: + + /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID + /// + /// The row-major extent and threadblock offset are forwarded to the underlying + /// iterator as PitchLinearCoord(column, row). + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix( + Params const &params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const &threadblock_offset ///< Initial offset of threadblock + ): + iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row()) + ) { } + + /// Construct a PredicatedTileIteratorTriangularMatrix with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix( + Params const &params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ): 
PredicatedTileIteratorTriangularMatrix(params, pointer, extent, thread_id, make_Coord(0, 0)) { } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the iterator's + /// internal pointer is reverted to the first "steady state" tile. Subsequent calls + /// are lightweight and must only update the internal pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorTriangularMatrix operator++(int) { + PredicatedTileIteratorTriangularMatrix self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Enables the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, LongIndex byte_offset) { + 
iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..df551c13f52834bfa6258104f99c7ed008342279 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/predicated_vector_access_iterator.h @@ -0,0 +1,417 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Templates implementing computing the addresses of loading small + vectors from the global memory. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedVectorAccessIterator: primary template (declaration only); +/// layout-specific partial specializations provide the implementation. +/// +template < + /// Shape of the vector accessed by the entire threadblock + typename Shape, + /// Shape of the vector accessed by the warp + typename WarpShape, + /// Type of Element + typename Element, + /// Layout of the vector + typename Layout, + /// Number of elements for each access + int ElementsPerAccess, + /// Enables handling of a residual (partial) tile; defaults to false + bool EnableResidualAccess = false +> +class PredicatedVectorAccessIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Vector access iterator specialized for vectors, e.g. 
scale and bias +/// Thread arrangements are for TensorOps +/// +template < + typename Shape_, + typename WarpShape_, + typename Element_, + int ElementsPerAccess, + bool EnableResidualAccess +> +class PredicatedVectorAccessIterator < + Shape_, + WarpShape_, + Element_, + layout::PitchLinear, + ElementsPerAccess, + EnableResidualAccess +> { + public: + + using Shape = Shape_; + using WarpShape = WarpShape_; + using Element = Element_; + using Layout = layout::PitchLinear; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + +// static int const kElementsPerAccess = 128 / sizeof_bits::value; + static int const kElementsPerAccess = ElementsPerAccess; + static int const kThreads = 32; + static int const kRowsPerIteration = 8; + static int const kThreadsPerRow = kThreads / kRowsPerIteration; + /// Mask applied to thread_id to obtain the thread's position within its row; + /// NOTE(review): 0x3 equals kThreadsPerRow - 1 for the constants above + static int const kThreadsPerRowMask = 0x3; + static int const kIterations = WarpShape::kContiguous / (kThreadsPerRow * kElementsPerAccess); + static int const kWarpCountStrided = Shape::kStrided / WarpShape::kStrided; + + using AccessType = AlignedArray; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char *; + + private: + // + // Data members + // + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Extent of tensor + TensorCoord extent_; + + /// Logical offset of each thread: sum of the threadblock, warp and per-thread offsets + TensorCoord thread_offset_; + + /// Current iteration index; operator++ wraps it back to 0 once it reaches kIterations + LongIndex iteration_; + + /// True while a residual offset is still pending, i.e. until the first advance() consumes it + bool is_residual_; + + /// Residual offset (extent_.contiguous() % WarpShape::kContiguous), set only when it is non-zero + TensorCoord residual_offset_; + + public: + /// Constructs a vector access iterator + CUTLASS_HOST_DEVICE + PredicatedVectorAccessIterator( + /// Pointer to the start of the vector + ConstPointer pointer, + /// 
Extent of vector + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// ID of each participating warp + int warp_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : pointer_(reinterpret_cast( + const_cast(pointer))), + extent_(extent), + is_residual_(false) { + + + int warp_offset = (warp_id / kWarpCountStrided) * WarpShape::kContiguous; + + // Per-thread offset in logical coordinates of tensor + + thread_offset_ = threadblock_offset + TensorCoord(warp_offset, 0) + + TensorCoord((thread_id & kThreadsPerRowMask) * kElementsPerAccess, 0); + + set_iteration_index(0); + + if(EnableResidualAccess) { + // compute residual offset + typename TensorCoord::Index residual_size = extent_.contiguous() % WarpShape::kContiguous; + if (residual_size) { + is_residual_ = true; + residual_offset_ = make_Coord(residual_size, 0); + } + } + } + + /// Construct a PredicatedVectorAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedVectorAccessIterator( + /// Pointer to start of vector + ConstPointer pointer, + /// Extent of vector + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// ID of each participating warp + int warp_id) + : PredicatedVectorAccessIterator(pointer, extent, thread_id, warp_id, + make_Coord(0, 0)) {} + + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iteration_ = index; + } + + /// Advances the iterator by whole warp tiles along the contiguous dimension; + /// the strided component of tile_offset is not used + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + + thread_offset_ = + thread_offset_ + + TensorCoord(WarpShape::kContiguous * tile_offset.contiguous(), 0); + } + + /// Returns a pointer to the current access, computed from the thread offset and the iteration index + CUTLASS_HOST_DEVICE + AccessType *get() const { + + return reinterpret_cast( + pointer_ + + ((thread_offset_.contiguous() + iteration_ * kThreadsPerRow * kElementsPerAccess) + * sizeof_bits::value / 8)); + } + + /// 
Pre-increment: advances the iteration index, wrapping to 0 at kIterations, and returns a reference to self. + CUTLASS_HOST_DEVICE + PredicatedVectorAccessIterator &operator++() { + ++iteration_; + if(iteration_ >= kIterations) + iteration_ = 0; + + return *this; + } + + /// Moves to the next tile: consumes the pending residual offset if there is one, + /// otherwise advances by one whole warp tile along the contiguous dimension. + CUTLASS_HOST_DEVICE + void advance() { + if(EnableResidualAccess && is_residual_) { + is_residual_ = false; + thread_offset_ += residual_offset_; + } + else + add_tile_offset(TensorCoord(1, 0)); + } + + /// Post-increment: advances the iteration index and returns a copy taken before the increment. + CUTLASS_HOST_DEVICE + PredicatedVectorAccessIterator operator++(int) { + PredicatedVectorAccessIterator self(*this); + operator++(); + return self; + } + + /// Returns true when the element offset of the current access is below extent_.contiguous() + CUTLASS_HOST_DEVICE + bool valid() { + return ((thread_offset_.contiguous() + + iteration_ * kThreadsPerRow * kElementsPerAccess) < extent_.contiguous()); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedVectorAccessIterator for row-major data. 
+/// +template < + typename Shape_, + typename WarpShape_, + typename Element_, + int ElementsPerAccess, + bool EnableResidualAccess +> +class PredicatedVectorAccessIterator< + Shape_, + WarpShape_, + Element_, + layout::RowMajor, + ElementsPerAccess, + EnableResidualAccess +> { + public: + + using Shape = Shape_; + using WarpShape = WarpShape_; + using Element = Element_; + using Layout = layout::RowMajor; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using ConstPointer = const Element *; + using NonConstPointer = typename platform::remove_const::type *; + + using UnderlyingIterator = PredicatedVectorAccessIterator< + layout::PitchLinearShape, + layout::PitchLinearShape, + Element, + layout::PitchLinear, + ElementsPerAccess, + EnableResidualAccess>; + + using AccessType = typename UnderlyingIterator::AccessType; + static int const kElementsPerAccess = UnderlyingIterator::kElementsPerAccess; + static int const kRowsPerIteration = UnderlyingIterator::kRowsPerIteration; + static int const kThreads = UnderlyingIterator::kThreads; + static int const kIterations = UnderlyingIterator::kIterations; + + private: + // + // Data members + // + + /// Underlying pitch-linear iterator that every member function delegates to + UnderlyingIterator iterator_; + + public: + /// Constructs the iterator from a pointer, extent, thread/warp IDs and threadblock offset; + /// extent and offset are forwarded to the underlying iterator as (column, row) + CUTLASS_HOST_DEVICE + PredicatedVectorAccessIterator( + /// Pointer to the start of the vector + ConstPointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// ID of each participating warp + int warp_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : iterator_(pointer, layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, warp_id, + 
layout::PitchLinearCoord(threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedVectorAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + PredicatedVectorAccessIterator( + ConstPointer pointer, ///< Pointer to the start of the vector + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + int warp_id ///< ID of each participating warp + ) + : PredicatedVectorAccessIterator(pointer, extent, thread_id, warp_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles; the offset is forwarded to the underlying iterator as {column, row} + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const &tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns the pointer produced by the underlying iterator + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Pre-increment: forwards to the underlying iterator's operator++ and + /// returns a reference to self. + CUTLASS_HOST_DEVICE + PredicatedVectorAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Post-increment: forwards to operator++ and returns a copy taken before + /// the increment. + CUTLASS_HOST_DEVICE + PredicatedVectorAccessIterator operator++(int) { + PredicatedVectorAccessIterator self(*this); + operator++(); + return self; + } + + /// Forwards to the underlying iterator's advance(). 
+ CUTLASS_HOST_DEVICE + void advance() { + iterator_.advance(); + } + + /// Returns whether access is valid or not, as reported by the underlying iterator + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..1aae46988418c72a9322b7e6b47e1dfe4fadff8d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h @@ -0,0 +1,253 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Templates implementing computing the addresses of storing of small + scale and bias vectors in the shared memory. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// RegularScaleBiasVectorAccessIterator: primary template (declaration only) +/// +template +class RegularScaleBiasVectorAccessIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularScaleBiasVectorAccessIterator { + public: + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + /// Number of elements per access (128 divided by the element size in bits) + static int const kElementsPerAccess = 128 / sizeof_bits::value; + /// Number of accesses needed to cover Shape::kContiguous elements + static int const kThreads = Shape::kContiguous / kElementsPerAccess; + using AccessType = Array; + + private: + // + // Data members + // + + /// Pointer to this thread's first access, set in the constructor + AccessType *pointer_; + + /// Byte offset accumulated by add_pointer_offset() and applied in get() + Index byte_offset_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularScaleBiasVectorAccessIterator( + TensorRef scale_bias_ref, ///< Pointer to the start of the scale and bias + ///< vector + int thread_id ///< ID of each participating thread + ) + : byte_offset_(0) { + // Per-thread offset in logical coordinates of tensor + int thread_offset = thread_id * kElementsPerAccess; + + 
// initialize pointer + pointer_ = + reinterpret_cast(scale_bias_ref.data() + thread_offset); + + set_iteration_index(0); + } + + /// No-op: this specialization keeps no iteration index + /// (presumably retained for interface parity with other iterators) + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Adds a pointer offset in units of Element (converted to bytes with sizeof(Element)) + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns the internal pointer displaced by the accumulated byte offset + CUTLASS_DEVICE + AccessType *get() const { + + char *access_byte_ptr = + reinterpret_cast(pointer_); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Pre-increment is a no-op here; returns a reference to self. + CUTLASS_HOST_DEVICE + RegularScaleBiasVectorAccessIterator &operator++() { return *this; } + + /// Post-increment; returns a copy of self (operator++ itself changes nothing). + CUTLASS_HOST_DEVICE + RegularScaleBiasVectorAccessIterator operator++(int) { + RegularScaleBiasVectorAccessIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset in the unit of tile; only the contiguous component is used. + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + // Multiply by 2 because the scale and bias vectors belonging to the same + // stage are stored next to each other. 
+ add_pointer_offset(coord.contiguous() * Shape::kContiguous * 2); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for row major layouts; wraps the pitch-linear +/// specialization and delegates every operation to it +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularScaleBiasVectorAccessIterator< + Shape_, Element_, + layout::RowMajor> { + public: + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + /// Underlying iterator type + using UnderlyingIterator = RegularScaleBiasVectorAccessIterator< + layout::PitchLinearShape, Element, + layout::PitchLinear>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularScaleBiasVectorAccessIterator( + TensorRef scale_bias_ref, ///< Pointer to the start of the scale and bias + ///< vector + int thread_id ///< ID of each participating thread + ) + : iterator_({scale_bias_ref.data(), scale_bias_ref.stride()}, thread_id) { + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns the pointer produced by the underlying iterator + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset; the coordinate is forwarded to the underlying iterator as {column, row} + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); 
+ } + + /// Pre-increment: forwards to the underlying iterator and returns a reference to self. + CUTLASS_HOST_DEVICE + RegularScaleBiasVectorAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Post-increment: forwards to the underlying iterator and returns a copy taken beforehand. + CUTLASS_HOST_DEVICE + RegularScaleBiasVectorAccessIterator operator++(int) { + RegularScaleBiasVectorAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..cfb491b5a4b5f4e1b757f99110f6a9fd28675088 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator.h @@ -0,0 +1,58 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing the address computation of storing of tiles + from pitch-linear rank=2 tensors. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// RegularTileAccessIterator: primary template (declaration only); no +/// definition is provided in this header. +template ::value* ThreadMap::kElementsPerAccess / 8> +class RegularTileAccessIterator; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h new file mode 100644 index 0000000000000000000000000000000000000000..adda9339b87865799c56baba4c3f8df580e26ac5 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h @@ -0,0 +1,408 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing computing the addresses of storing of tiles + from pitch-linear rank=2 tensors. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::PitchLinear, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride of the tensor divided by ThreadMap::kElementsPerAccess (i.e. counted in accesses) + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Byte offset accumulated by add_pointer_offset() and applied in get() + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + /// 
Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // initialize pointer + pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_base)); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index; index is split into a contiguous part + /// (index % Iterations::kContiguous) and a strided part (index / Iterations::kContiguous) + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element (converted to bytes with sizeof(Element)) + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer to the access selected by the current contiguous/strided + /// iteration, displaced by the accumulated byte offset + CUTLASS_DEVICE + AccessType *get() const { + + AccessType *access_ptr = pointer_; + + int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next access: the contiguous iteration first, then the + /// strided one; both wrap to zero after the last access. The pointer is unchanged. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iterations::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_strided_ == ThreadMap::Iterations::kStrided) + // which means we enter the next tile. 
+ iteration_strided_ = 0; + + return *this; + } + + /// Post-increment: advances as operator++ does and returns a copy taken beforehand. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset in the unit of tile. + /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. + /// The layouts below are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. + /// For row major A operand, k dimension is contiguous dimension; + /// For col major A operand, k dimension is strided dimension; + /// For row major B operand, k dimension is strided dimension; + /// For col major B operand, k dimension is contiguous dimension. + /// The two classes below map col/row major to the pitch linear coordinates used + /// in this base class. + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset(coord.contiguous() * Shape::kContiguous + + coord.strided() * Shape::kStrided * stride_ * + ThreadMap::kElementsPerAccess); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for column major layouts; wraps the pitch-linear +/// specialization and delegates every operation to it +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::ColumnMajor, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using 
TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns the pointer produced by the underlying iterator + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset; the coordinate is forwarded to the underlying iterator as {row, column} + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Pre-increment: forwards to the underlying iterator and returns a reference to self. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Post-increment: forwards to the underlying iterator and returns a copy taken beforehand. 
+ CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for row major layouts +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::RowMajor, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h new file mode 100644 index 0000000000000000000000000000000000000000..71c89686a71995b45f9d4cf0fd1f0fba12ca7d8a --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h @@ -0,0 +1,587 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing computing the addresses of storing of tiles + from pitch-linear rank=2 tensors. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + + +//////////////////////////////////////////////////////////////////////////////// + +template ::value* ThreadMap::kElementsPerAccess / 8 + > +class RegularTileAccessIteratorDirectConv; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps with dynamic_iterations OFF +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIteratorDirectConv< + Shape_, Element_, + layout::PitchLinear, + AdvanceRank, ThreadMap_, false, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + 
AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // initialize pointer + pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_base)); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_num(int num) { + //Do nothing + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + + AccessType *access_ptr = pointer_; + + int access_offset = iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv operator++(int) { + RegularTileAccessIteratorDirectConv prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset in the unit of tile. + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset(coord.contiguous() * Shape::kContiguous + + coord.strided() * ThreadMap::Iterations::kStrided * + ThreadMap::Delta::kStrided * stride_ * ThreadMap::kElementsPerAccess); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps with dynamic_iterations ON +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIteratorDirectConv< + Shape_, Element_, + layout::PitchLinear, + AdvanceRank, ThreadMap_,true, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename 
Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + /// Total iterattions in the strided dimension: Dynamic value + int total_iteration_strided_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : stride_(ref.stride(0) / ThreadMap::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // initialize pointer + pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_base)); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_num(int num) { + total_iteration_strided_ = num; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + + AccessType *access_ptr = pointer_; + + int access_offset = 
iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < total_iteration_strided_) { + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv operator++(int) { + RegularTileAccessIteratorDirectConv prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset in the unit of tile. 
+ CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset(coord.contiguous() * Shape::kContiguous + + coord.strided() * total_iteration_strided_ * ThreadMap::Delta::kStrided * stride_ * + ThreadMap::kElementsPerAccess); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for column major layouts +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIteratorDirectConv< + Shape_, Element_, + layout::ColumnMajor, + AdvanceRank, ThreadMap_, Dynamic_iterations , Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIteratorDirectConv< + layout::PitchLinearShape, Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap_, + Dynamic_iterations>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_num(int num) { + iterator_.set_iteration_num(num); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv operator++(int) { + RegularTileAccessIteratorDirectConv prev(*this); + ++iterator_; + + return prev; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for row major layouts +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIteratorDirectConv< + Shape_, Element_, + layout::RowMajor, + AdvanceRank, ThreadMap_, Dynamic_iterations, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIteratorDirectConv< + layout::PitchLinearShape, Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap_, + Dynamic_iterations>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_num(int num) { + iterator_.set_iteration_num(num); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + RegularTileAccessIteratorDirectConv operator++(int) { + RegularTileAccessIteratorDirectConv prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..e172447fa96b02e11246f5f397911841c52eff4c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h @@ -0,0 +1,821 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing computing the addresses of storing of tiles + from pitch-linear rank=2 tensors. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::TensorOpMultiplicandCongruous::value, + Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = + layout::TensorOpMultiplicandCongruous::value, + Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + static int const kCrosswise = Crosswise; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. 
+ static int const kAccessSizeInBits = 128; + + static_assert(sizeof_bits::value * + ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 128bs"); + + ///< Number of pointers + static int const kPointerCount = + (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); + }; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_[Detail::kPointerCount]; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : stride_(ref.stride(0) * Layout::kFactor / Layout::kElementsPerAccess), + byte_offset_(0) { + layout::PitchLinearCoord thread_offset_base = + ThreadMap::initial_offset(thread_id); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kPointerCount; ++i) { + // This is the offset of a thread within a threadblock tile for a specific + // pointer (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = + thread_offset_base + + layout::PitchLinearCoord{ + 0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; + + // initialize pointer + pointer_[i] = reinterpret_cast( + ref.data() + ref.offset(thread_offset_in_threadblock_tile)); + } + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + 
CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + AccessType *access_ptr = pointer_[iteration_strided_ & 1]; + int stride_idx = (iteration_strided_ & ~1); + + int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ / Layout::kFactor + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_strided_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset(coord.contiguous() * Shape::kContiguous * Layout::kFactor + + coord.strided() * Shape::kStrided * stride_ * + Layout::kElementsPerAccess / Layout::kFactor); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major congruous TensorOp formats. 
+/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::ColumnMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous::value, + Crosswise>, + (kAdvanceRank == 0 ? 
0 : 1), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major congruous TensorOp formats. 
+/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::RowMajorTensorOpMultiplicandCongruous::value, + Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous::value, + Crosswise>, + (kAdvanceRank == 0 ? 
1 : 0), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for crosswise arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator::value, Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = + layout::TensorOpMultiplicandCrosswise::value, + Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + static int const kCrosswise = Crosswise; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + static_assert(!(ThreadMap::Delta::kContiguous % kCrosswise), + "kCrosswise is the smallest unit in the contiguous dimension " + "for shared memory swizzling."); + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. 
+ static int const kAccessSizeInBits = 128; + + static_assert(sizeof_bits::value * + ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 128bs"); + + /// Number of pointers + /// + /// Note: TN kblock32 layouts only need 1 pointer, but strangely + /// reducing pointer count hurts performance + static int const kPointerCount = + (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); + }; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Total number of sections. The memory is divided into stages. One stage + /// can store one tile. A stage is divided into sections. Interleaved layouts + /// can have multiple sections in a stage. Other layouts have only one section + /// in a stage. + int sections_; + + /// Sections that a stage has + int sections_per_stage_; + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_[Detail::kPointerCount]; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : sections_(ref.stride(0) / kCrosswise), + sections_per_stage_(Shape::kContiguous / kCrosswise), + // stride_ = kCrosswise x sections_ x kFactor + stride_(ref.stride(0) * Layout::kFactor / Layout::kElementsPerAccess), + byte_offset_(0) { + layout::PitchLinearCoord thread_offset_base = + ThreadMap::initial_offset(thread_id); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kPointerCount; ++i) { + // This is the offset of a thread within a threadblock tile for a specific + // pointer (units of elements) + layout::PitchLinearCoord
thread_offset_in_threadblock_tile = + thread_offset_base + + layout::PitchLinearCoord{ + 0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; + // initialize pointer + pointer_[i] = reinterpret_cast(ref.data()) + + ref.offset(thread_offset_in_threadblock_tile) / + Layout::kElementsPerAccess; + } + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_offset_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + AccessType *access_ptr = pointer_[iteration_strided_ & 1]; + int stride_idx = (iteration_strided_ & ~1); + + int access_offset = + stride_idx * ThreadMap::Delta::kStrided * stride_ / Layout::kFactor + + // kCrosswise elements in the contiguous dimension would span to a + // shared memory cache line. + iteration_contiguous_ * (ThreadMap::Delta::kContiguous / kCrosswise) * + Layout::TileShape::kContiguous; + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iterations::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_strided_ == ThreadMap::Iterations::kStrided) + // which means we enter the next section.
+ iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset(coord.contiguous() * sections_per_stage_ * stride_ * + ThreadMap::kElementsPerAccess / sections_ + + coord.strided() * Shape::kStrided * stride_ * + Layout::kElementsPerAccess / Layout::kFactor); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major crosswise TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCrosswise::value, + Crosswise>, + (kAdvanceRank == 0 ? 
0 : 1), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major crosswise TensorOp formats. 
+/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator::value, Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCrosswise::value, + Crosswise>, + (kAdvanceRank == 0 ? 
1 : 0), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h new file mode 100644 index 0000000000000000000000000000000000000000..b55f841eee2e09aec8af5c8ec945a1997705c9f6 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h @@ -0,0 +1,1532 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing computing the addresses of storing of tiles + from pitch-linear rank=2 tensors. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::TensorOpMultiplicandCongruous64b, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorOpMultiplicandCongruous64b; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + static_assert(ThreadMap::kThreads / 32 > 1, + "This tile iterator requires at least two warps."); + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 64 bits in + /// length.
+ static int const kAccessSizeInBits = 64; + + static_assert(sizeof_bits::value * + ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 64b"); + + ///< Number of pointers + static int const kPointerCount = 1; + }; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): + stride_(ref.stride(0) / Layout::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // This is the offset of a thread within a threadblock tile for a specific + // pointer (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; + + // initialize pointer + pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + + AccessType *access_ptr = pointer_; + + int access_offset = 
iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iterations::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_strided_ == ThreadMap::Iterations::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + + RegularTileAccessIterator prev(*this); + + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + + add_pointer_offset( + coord.contiguous() * Shape::kContiguous + + coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major congruous TensorOp formats.
+/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::ColumnMajorTensorOpMultiplicandCongruous64b, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous64b; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous64b, + (kAdvanceRank == 0 ? 
0 : 1), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major congruous TensorOp formats. 
+/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicandCongruous64b; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous64b, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + 
iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for crosswise arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::TensorOpMultiplicand64bCrosswise, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorOpMultiplicand64bCrosswise; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + static_assert(ThreadMap::kThreads / 32 > 1, + "This tile iterator requires at least two warps."); + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 64 bits in + /// length.
+ static int const kAccessSizeInBits = 64; + + static_assert(sizeof_bits::value * + ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 64b"); + + ///< Number of pointers - two pointers are needed if making more than 4 iterations along + ///< strided dimension + static int const kPointerCount = (ThreadMap::Iterations::kStrided > 4 ? 2 : 1); + }; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_[Detail::kPointerCount]; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_DEVICE + RegularTileAccessIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): + stride_(ref.stride(0) / ThreadMap::kElementsPerAccess) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // This is the offset of a thread within a threadblock tile for a specific + // pointer (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; + + // initialize pointer + pointer_ = reinterpret_cast(ref.data()); + + byte_offset_[0] = ref.offset(thread_offset_in_threadblock_tile) * sizeof(Element); + + if (Detail::kPointerCount == 2) { + byte_offset_[1] = byte_offset_[0] ^ 8; + } + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + 
CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + pointer_ += pointer_offset / ThreadMap::kElementsPerAccess; + } + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + + // Map the logical contiguous and strided access to the internal swizzled structure. + int uniform_offset = (iteration_strided_ & 0x3) * stride_ + (iteration_strided_ >> 3) * 16 + stride_ * ThreadMap::Delta::kContiguous * iteration_contiguous_; + + char *access_byte_ptr = reinterpret_cast(pointer_ + uniform_offset); + + int byte_offset; + + // This iterator may require two byte offsets if it must load more than 8 rows (or 2 iterations) + // in the strided dimension + if (Detail::kPointerCount == 2 && (iteration_strided_ & 0x4)) { + byte_offset = byte_offset_[1]; + } + else { + byte_offset = byte_offset_[0]; + } + + return reinterpret_cast(access_byte_ptr + byte_offset); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iterations::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_strided_ == ThreadMap::Iterations::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory.
+ CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + + RegularTileAccessIterator prev(*this); + + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + + add_pointer_offset(coord.strided() * Shape::kStrided + coord.contiguous() * Shape::kContiguous * stride_); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major crosswise TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::ColumnMajorTensorOpMultiplicand64bCrosswise, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicand64bCrosswise; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicand64bCrosswise, + (kAdvanceRank == 0 ? 
0 : 1), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major crosswise TensorOp formats. 
+/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicand64bCrosswise; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicand64bCrosswise, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + 
iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::TensorOpMultiplicandCongruous128b, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorOpMultiplicandCongruous128b; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + static_assert(ThreadMap::kThreads / 32 > 1, + "This tile iterator requires at least two warps."); + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. 
+ static int const kAccessSizeInBits = 128; + + static_assert(sizeof_bits::value * + ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 128b"); + + ///< Number of pointers + static int const kPointerCount = 1; + }; + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): + stride_(ref.stride(0) / Layout::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // This is the offset of a thread within a threadblock tile for a specific + // pointer (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; + + // initialize pointer + pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + + AccessType *access_ptr = pointer_; + + int access_offset 
= iteration_strided_ * ThreadMap::Delta::kStrided * stride_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + + RegularTileAccessIterator prev(*this); + + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + + add_pointer_offset( + coord.contiguous() * Shape::kContiguous + + coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major congruous TensorOp formats. 
+/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::ColumnMajorTensorOpMultiplicandCongruous128b, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous128b; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous128b, + (kAdvanceRank == 0 ? 
0 : 1), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major congruous TensorOp formats. 
+/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicandCongruous128b; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous128b, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): + iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + 
iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::TensorOpMultiplicandCrosswise128x4, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::TensorOpMultiplicandCrosswise128x4; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + static_assert(ThreadMap::kThreads / 32 > 1, + "This tile iterator requires at least two warps."); + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. 
+ static int const kAccessSizeInBits = 128; + + static_assert(sizeof_bits::value * + ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 128b"); + + ///< Number of pointers + static int const kPointerCount = 1; + }; + + + static_assert(!(ThreadMap::Iterations::kStrided % 2), "This iterator requires at least two iterations along the strided dimension"); + + /// Element type per access + using AccessType = Array; + + private: + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType *pointer_; + + /// Internal byte offset + Index byte_offset_; + + /// Iteration in the contiguous dimension + int iteration_contiguous_; + + /// Iteration in the strided dimension + int iteration_strided_; + + public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_DEVICE + RegularTileAccessIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): + stride_(ref.stride(0) / Layout::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + // This is the offset of a thread within a threadblock tile for a specific + // pointer (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = thread_offset_base; + + // initialize pointer + pointer_ = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); + + set_iteration_index(0); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + + iteration_contiguous_ = index % ThreadMap::Iterations::kContiguous; + iteration_strided_ = index / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + byte_offset_ += pointer_offset * sizeof(Element); + } + + 
/// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + + AccessType *access_ptr = pointer_; + + int offset_c = (iteration_contiguous_ * ThreadMap::Delta::kContiguous + (iteration_strided_ & 1) * 2); + int offset_s = (iteration_strided_ / 2) * 8; + + int access_offset = offset_c * stride_ + offset_s; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + + return reinterpret_cast(access_byte_ptr + byte_offset_); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iteration_contiguous_; + + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) + return *this; + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + iteration_contiguous_ = 0; + ++iteration_strided_; + + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + iteration_strided_ = 0; + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + + RegularTileAccessIterator prev(*this); + + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + + add_pointer_offset( + coord.contiguous() * Shape::kContiguous * stride_ + + coord.strided() * Shape::kStrided * Layout::kElementsPerAccess); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major congruous TensorOp formats. 
+/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator< + Shape_, Element_, + layout::ColumnMajorTensorOpMultiplicandCrosswise128x4, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise128x4; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCrosswise128x4, + (kAdvanceRank == 0 ? 
0 : 1), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major congruous TensorOp formats. 
+/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileAccessIterator { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicandCrosswise128x4; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileAccessIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCrosswise128x4, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + using AccessType = typename UnderlyingIterator::AccessType; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileAccessIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): + iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { iterator_.set_iteration_index(index); } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + 
iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + ++iterator_; + + return prev; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..be07e43f6f45132f79d95afb95714c4392149b66 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator.h @@ -0,0 +1,62 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing storing of tiles from pitch-linear rank=2 tensors. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int Alignment = sizeof_bits::value * ThreadMap::kElementsPerAccess / 8 +> +class RegularTileIterator; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h new file mode 100644 index 0000000000000000000000000000000000000000..6c186ce3fe0650c3f8927d84f1983916d9d1867f --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h @@ -0,0 +1,552 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile + first, with the objective of minimizing predicate mask updates during steady-state operation. + + A precomputed "Params" object minimizes the amount of state that must be stored in registers, + and integer addition is used to advance the pointer through memory. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Regular tile iterator specialized for pitch-linear. This one is used by 2-stage SIMT kernels +/// and sparse tensor core meta data. +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator { +public: + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Fragment = Array; + + using AccessType = AlignedArray; + + static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, + "Advance rank may only be along the contiguous or strided dimensions."); + +private: + + // + // Types + // + + // + // Data members + // + + /// Pointer to memory + uint8_t *pointer_; + + /// Stride quantity + StrideIndex stride_; + + /// Amount to increment pointer along strided dimension + Index increment_strided_; + + /// Amount to advance pointer between tiles + Index increment_advance_; + +public: + + CUTLASS_DEVICE + RegularTileIterator(): pointer_(nullptr), increment_strided_(0), increment_advance_(0) { } + + CUTLASS_DEVICE + RegularTileIterator( + TensorRef const &ref, + int thread_idx + ): + 
pointer_(reinterpret_cast(ref.data()) + (ref.offset(ThreadMap::initial_offset(thread_idx)) * sizeof_bits::value / 8)) { + + stride_ = ref.stride()[0]; + increment_strided_ = (ref.stride()[0] * sizeof_bits::value) * ThreadMap::Delta::kStrided / 8; + + increment_advance_ = + (kAdvanceRank == 0 ? + Shape::kContiguous * sizeof_bits::value / 8 : + Shape::kStrided * (ref.stride()[0] * sizeof_bits::value / 8)); + } + + /// Loads a fragment + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + uint8_t const *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + AccessType const *access_ptr = reinterpret_cast(byte_pointer); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int idx = c + s * ThreadMap::Iterations::kContiguous; + frag_ptr[idx] = access_ptr[c * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess]; + } + + if (s + 1 < ThreadMap::Iterations::kStrided) { + byte_pointer += increment_strided_; + } + } + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag, TensorCoord const & tile_offset) { + load_with_pointer_offset( + frag, + tile_offset.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + + tile_offset.strided() * Shape::kStrided * stride_ + ); + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + uint8_t *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + AccessType *access_ptr = reinterpret_cast(byte_pointer); + + 
CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int idx = c + s * ThreadMap::Iterations::kContiguous; + access_ptr[c * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess] = frag_ptr[idx]; + } + + if (s + 1 < ThreadMap::Iterations::kStrided) { + byte_pointer += increment_strided_; + } + } + } + + /// Stores a fragment at the given tile offset. + /// NOTE(review): load(frag, tile_offset) divides the contiguous term by + /// ThreadMap::kElementsPerAccess, while this overload does not - confirm which is intended. + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, TensorCoord const & tile_offset) { + store_with_pointer_offset( + frag, + tile_offset.contiguous() * Shape::kContiguous + tile_offset.strided() * Shape::kStrided * stride_ + ); + } + + /// Stores a fragment at the current pointer position (pointer offset of zero) + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Advances the pointer forward by increment_advance_ bytes (one tile along the advance rank) + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + pointer_ += increment_advance_; + return *this; + } + + /// Moves the pointer backward by increment_advance_ bytes (one tile along the advance rank) + CUTLASS_HOST_DEVICE + RegularTileIterator &operator--() { + pointer_ -= increment_advance_; + return *this; + } + + /// Adds a pointer offset. + /// NOTE(review): the offset is added unscaled to the uint8_t pointer_, so it is applied as a + /// byte count (add_tile_offset passes bytes), not as a count of Elements - confirm callers. + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset; + } + + /// Adds a tile offset in units of whole tiles. + /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. + /// The layouts below are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. + /// For row major A operand, k dimension is contiguous dimension; + /// For col major A operand, k dimension is strided dimension; + /// For row major B operand, k dimension is strided dimension; + /// For col major B operand, k dimension is contiguous dimension. + /// The two classes below map col/row major to the pitch linear coordinates used + /// in this base class. 
+ CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + int offset = sizeof_bits::value * + (coord.contiguous() * Shape::kContiguous + coord.strided() * Shape::kStrided * stride_) / 8; + add_pointer_offset(offset); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { +#if 0 + AccessType *access_ptr = pointer_[iteration_strided_ & 1]; + int stride_idx = (iteration_strided_ & ~1); + + int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + + iteration_contiguous_ * ThreadMap::Delta::kContiguous / + ThreadMap::kElementsPerAccess; + + char *access_byte_ptr = + reinterpret_cast(access_ptr + access_offset); + return reinterpret_cast(access_byte_ptr + byte_offset_); +#endif + return reinterpret_cast(pointer_); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Regular tile iterator specialized for row major +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator { +public: + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Fragment = Array; + + using Underlying = RegularTileIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + kAlignment + >; + + using AccessType = typename Underlying::AccessType; + + static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, + "Advance rank may only be along the row or column dimensions."); + +private: + + Underlying iterator_; + +public: + + CUTLASS_DEVICE + RegularTileIterator() { } + + CUTLASS_DEVICE + RegularTileIterator( + TensorRef const &ref, + int thread_idx + ): + iterator_({ref.data(), ref.stride()}, thread_idx) { + + } + + /// Loads a fragment, forwarding pointer_offset unchanged to the underlying pitch-linear iterator + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment at a tile offset given as a (row, column) coordinate. + /// NOTE(review): this forwards a {column, row} braced coordinate to load_with_pointer_offset, + /// whose second parameter in the pitch-linear iterator is a scalar Index; the underlying + /// load(frag, tile_offset) overload looks like the intended callee - confirm. + CUTLASS_HOST_DEVICE + void load(Fragment &frag, TensorCoord const & tile_offset) { + iterator_.load_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); + } + + /// Loads a fragment at the current position (pointer offset of zero) + CUTLASS_HOST_DEVICE + void load(Fragment &frag) { + iterator_.load_with_pointer_offset(frag, 0); + } + + /// Stores a fragment, forwarding pointer_offset unchanged to the underlying pitch-linear iterator + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Stores a fragment at a tile offset given as a (row, column) coordinate. + /// NOTE(review): this forwards a {column, row} braced coordinate to store_with_pointer_offset, + /// whose second parameter in the pitch-linear iterator is a scalar Index; the underlying + /// store(frag, tile_offset) overload looks like the intended callee - confirm. + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, TensorCoord const & tile_offset) { + iterator_.store_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); + } + + /// Stores a fragment at the current position (pointer offset of zero) + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + iterator_.store_with_pointer_offset(frag, 0); + } + + /// Advances to the next tile by incrementing the underlying iterator + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Moves to the previous tile by decrementing the underlying iterator + CUTLASS_HOST_DEVICE + RegularTileIterator &operator--() { + --iterator_; + return *this; + } + + /// Forwards a pointer offset unchanged to the underlying iterator. + /// NOTE(review): the pitch-linear iterator adds this value to a uint8_t pointer unscaled, + /// so the effective unit appears to be bytes rather than Elements - confirm. + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void 
add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Overrides the internal iteration index (no-op here: the body is empty and index is ignored) + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + } + + /// Returns the pointer reported by the underlying iterator's get() + CUTLASS_HOST_DEVICE + AccessType *get() const { + return iterator_.get(); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Regular tile iterator specialized for column major (wraps the pitch-linear iterator) +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator { +public: + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Fragment = Array; + + using Underlying = RegularTileIterator< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap + >; + + using AccessType = typename Underlying::AccessType; + + static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, + "Advance rank may only be along the row or column dimensions."); + +private: + + Underlying iterator_; + +public: + + CUTLASS_DEVICE + RegularTileIterator() { } + + CUTLASS_DEVICE + RegularTileIterator( + TensorRef const &ref, + int thread_idx + ): + iterator_({ref.data(), ref.stride()}, thread_idx) { + + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag, TensorCoord const & tile_offset) { + iterator_.load_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag) { + iterator_.load_with_pointer_offset(frag, 0); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, TensorCoord const & tile_offset) { + iterator_.store_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + iterator_.store_with_pointer_offset(frag, 0); + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + RegularTileIterator &operator--() { + --iterator_; + return *this; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void 
add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return iterator_.get(); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h new file mode 100644 index 0000000000000000000000000000000000000000..5ed2e7fdd08ceafe772c97ab90f915c2268cabbb --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear_2dthreadtile.h @@ -0,0 +1,509 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile + first, with the objective of minimizing predicate mask updates during steady-state operation. + + A precomputed "Params" object minimizes the amount of state that must be stored in registers, + and integer addition is used to advance the pointer through memory. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int Alignment = sizeof_bits::value * ThreadMap::kElementsPerAccess / 8 +> +class RegularTileIterator2dThreadTile; + + +/// Regular tile iterator specialized for pitch-linear + 2d thread-tiled threadmapping +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator2dThreadTile { +public: + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Fragment = Array; + + static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, + "Advance rank may only be along the contiguous or strided dimensions."); + +private: + + // + // Types + // + + using AccessType = AlignedArray; + + // + // Data members + // + + /// Pointer to memory + uint8_t *pointer_; + + /// Stride quantity + StrideIndex stride_; + + /// Amount to increment pointer along strided dimension + LongIndex increment_strided_; + + /// Amount to advance pointer between tiles + LongIndex increment_advance_; + +public: + + CUTLASS_DEVICE + 
RegularTileIterator2dThreadTile(): pointer_(nullptr), increment_strided_(0), increment_advance_(0) { } + + CUTLASS_DEVICE + RegularTileIterator2dThreadTile( + TensorRef const &ref, + int thread_idx, + int interleave + ){ + + TensorCoord t = ThreadMap::initial_offset(thread_idx); + long int offset = t[0] * interleave + t[1] * ref.stride()[0]/interleave; + pointer_ = reinterpret_cast(ref.data() + offset); + + stride_ = ref.stride()[0] / interleave; + increment_strided_ = (ref.stride()[0] * sizeof_bits::value / 8) * ThreadMap::Delta::kStrided / interleave; + + increment_advance_ = + (kAdvanceRank == 0 ? + Shape::kContiguous * sizeof_bits::value / 8 : + Shape::kStrided * (ref.stride()[0] * sizeof_bits::value / 8) / interleave); + } + + /// Loads a fragment + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + uint8_t const *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + AccessType const *access_ptr = reinterpret_cast(byte_pointer); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int idx = c + s * ThreadMap::Iterations::kContiguous; + frag_ptr[idx] = access_ptr[c * ThreadMap::Delta::kContiguous / ThreadMap::ThreadAccessShape::kStrided]; + } + + if (s + 1 < ThreadMap::Iterations::kStrided) { + byte_pointer += increment_strided_; + } + } + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag, TensorCoord const & tile_offset) { + load_with_pointer_offset( + frag, + tile_offset.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + + tile_offset.strided() * Shape::kStrided * stride_ + ); + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment 
const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + uint8_t *byte_pointer = pointer_ + pointer_offset * sizeof_bits::value / 8; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + AccessType *access_ptr = reinterpret_cast(byte_pointer); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int idx = c + s * ThreadMap::Iterations::kContiguous; + access_ptr[c * ThreadMap::Delta::kContiguous / ThreadMap::ThreadAccessShape::kStrided] = frag_ptr[idx]; + } + + if (s + 1 < ThreadMap::Iterations::kStrided) { + byte_pointer += increment_strided_; + } + } + } + + /// Stores a fragment at the given tile offset. + /// NOTE(review): load(frag, tile_offset) divides the contiguous term by + /// ThreadMap::kElementsPerAccess, while this overload does not - confirm which is intended. + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, TensorCoord const & tile_offset) { + store_with_pointer_offset( + frag, + tile_offset.contiguous() * Shape::kContiguous + tile_offset.strided() * Shape::kStrided * stride_ + ); + } + + /// Stores a fragment at the current pointer position (pointer offset of zero) + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Advances the pointer forward by increment_advance_ (one tile along the advance rank) + CUTLASS_HOST_DEVICE + RegularTileIterator2dThreadTile &operator++() { + pointer_ += increment_advance_; + return *this; + } + + /// Moves the pointer backward by increment_advance_ (one tile along the advance rank) + CUTLASS_HOST_DEVICE + RegularTileIterator2dThreadTile &operator--() { + pointer_ -= increment_advance_; + return *this; + } + + /// Adds a pointer offset. + /// NOTE(review): the offset is added unscaled to pointer_, which is used as a uint8_t pointer, + /// so it is applied as a byte count (add_tile_offset passes bytes), not as a count of Elements. + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset; + } + + /// Adds a tile offset: converts the tile coordinate to a byte offset and applies it to the pointer + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + int offset = sizeof_bits::value * + (coord.contiguous() * Shape::kContiguous + coord.strided() * Shape::kStrided * stride_) / 8; + add_pointer_offset(offset); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Regular tile iterator specialized for interleaved layout + 2d thread-tiled 
threadmapping +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator2dThreadTile, AdvanceRank, ThreadMap_, Alignment> { +public: + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorInterleaved<4>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Fragment = Array; + + using Underlying = RegularTileIterator2dThreadTile< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + kAlignment + >; + + static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, + "Advance rank may only be along the row or column dimensions."); + +private: + + Underlying iterator_; + +public: + + CUTLASS_DEVICE + RegularTileIterator2dThreadTile() { } + + CUTLASS_DEVICE + RegularTileIterator2dThreadTile( + TensorRef const &ref, + int thread_idx + ): + iterator_({ref.data(), ref.stride()}, thread_idx, 4) { + + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag, TensorCoord const & tile_offset) { + iterator_.load_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag) { + iterator_.load_with_pointer_offset(frag, 0); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store(Fragment const 
&frag, TensorCoord const & tile_offset) { + iterator_.store_with_pointer_offset(frag, {tile_offset.column(), tile_offset.row()}); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + iterator_.store_with_pointer_offset(frag, 0); + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + RegularTileIterator2dThreadTile &operator++() { + ++iterator_; + return *this; + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + RegularTileIterator2dThreadTile &operator--() { + --iterator_; + return *this; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Regular tile iterator specialized for interleaved layout + 2d thread-tiled threadmapping +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator2dThreadTile, AdvanceRank, ThreadMap_, Alignment> { +public: + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorInterleaved<4>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Fragment = Array; + using PitchLinearThreadMap = PitchLinearStripminedThreadMap< layout::PitchLinearShape, + ThreadMap::kThreads, ThreadMap::ThreadAccessShape::kCount >; + + + using Underlying = RegularTileIterator2dThreadTile< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + 
(kAdvanceRank == 0 ? 0 : 1), + ThreadMap + >; + + static_assert(kAdvanceRank == 0 || kAdvanceRank == 1, + "Advance rank may only be along the row or column dimensions."); + +private: + + Underlying iterator_; + +public: + + CUTLASS_DEVICE + RegularTileIterator2dThreadTile() { } + + CUTLASS_DEVICE + RegularTileIterator2dThreadTile( + TensorRef const &ref, + int thread_idx + ): + iterator_({ref.data(), ref.stride()}, thread_idx, 4) { + + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag, TensorCoord const & tile_offset) { + iterator_.load_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); + } + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(Fragment &frag) { + iterator_.load_with_pointer_offset(frag, 0); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, TensorCoord const & tile_offset) { + iterator_.store_with_pointer_offset(frag, {tile_offset.row(), tile_offset.column()}); + } + + /// Stores a fragment + CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + iterator_.store_with_pointer_offset(frag, 0); + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + RegularTileIterator2dThreadTile &operator++() { + ++iterator_; + return *this; + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + RegularTileIterator2dThreadTile &operator--() { + --iterator_; + return *this; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void 
add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..723f328d976fc170d198282823e3da6876ec1ba6 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h @@ -0,0 +1,1107 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing storing of tiles from pitch-linear rank=2 tensors. +*/ + +#pragma once + +#include "cutlass/transform/threadblock/regular_tile_iterator.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileIterator< + Shape_, Element_, + layout::TensorOpMultiplicandCongruous::value, + Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using 
Shape = Shape_; + using Element = Element_; + using Layout = + layout::TensorOpMultiplicandCongruous::value, + Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Internal details made public to facilitate introspection + struct Detail { + + /// This iterator is specialized for an access size that is 128 bits in length. + static int const kAccessSizeInBits = 128; + + static_assert( + sizeof_bits::value * ThreadMap::kElementsPerAccess == kAccessSizeInBits, + "This iterator requires a policy whose access size is 128bs"); + }; + +private: + + /// Element type per access + using AccessType = Array; + +public: + + /// Fragment object to be loaded or stored + using Fragment = Array; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = RegularTileAccessIterator; + +private: + + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + +public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : address_iterator_(ref, thread_id) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + address_iterator_.add_tile_offset({0, 1}); + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + address_iterator_.add_tile_offset(coord); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + load_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, Index byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + + char const *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType const *access_ptr = reinterpret_cast(byte_ptr); + + frag_ptr[access_idx] = *access_ptr; + ++address_iterator_; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, Index byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + + char *byte_ptr = reinterpret_cast(address_iterator_.get()) + 
byte_offset; + AccessType *access_ptr = reinterpret_cast(byte_ptr); + + *access_ptr = frag_ptr[access_idx]; + ++address_iterator_; + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_byte_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileIterator< + Shape_, Element_, + layout::ColumnMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous::value, + Crosswise>, + (kAdvanceRank == 0 ? 
0 : 1), ThreadMap_>; + + public: + + /// Fragment object to be loaded or stored + using Fragment = Array; + +private: + + /// Underlying iterator + UnderlyingIterator iterator_; + +public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): iterator_({ref.data(), ref.stride()}, thread_id) { + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, + Index pointer_offset) { + + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major congruous TensorOp formats. 
+/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileIterator< + Shape_, Element_, + layout::RowMajorTensorOpMultiplicandCongruous::value, + Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCongruous::value, + Crosswise>, + (kAdvanceRank == 0 ? 
1 : 0), ThreadMap_>; + + public: + + /// Fragment object to be loaded or stored + using Fragment = Array; + +private: + + /// Underlying iterator + UnderlyingIterator iterator_; + +public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): iterator_({ref.data(), ref.stride()}, thread_id) { + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, + Index pointer_offset) { + + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for crosswise arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// 
ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileIterator::value, Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = + layout::TensorOpMultiplicandCrosswise::value, + Crosswise>; + + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. + static int const kAccessSizeInBits = 128; + + static_assert(sizeof_bits::value * ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 128bs"); + }; + + private: + /// Element type per access + using AccessType = Array; + + public: + /// Fragment object to be loaded or stored + using Fragment = + Array; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = RegularTileAccessIterator; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : address_iterator_(ref, thread_id) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + 
address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + address_iterator_.add_tile_offset({1, 0}); + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + address_iterator_.add_tile_offset(coord); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + address_iterator_.set_iteration_index(0); + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + frag_ptr[access_idx] = *(address_iterator_.get() + pointer_offset); + ++address_iterator_; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + store_with_byte_offset(frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, Index byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + + char *byte_ptr = reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType *access_ptr = 
reinterpret_cast(byte_ptr); + + *access_ptr = frag_ptr[access_idx]; + ++address_iterator_; + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major crosswise TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileIterator::value, Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCrosswise::value, + Crosswise>, + (kAdvanceRank == 0 ? 
0 : 1), ThreadMap_>; + + public: + /// Fragment object to be loaded or stored + using Fragment = Array; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major crosswise TensorOp formats. 
+/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileIterator::value, Crosswise>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Crosswise>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, Element, + layout::TensorOpMultiplicandCrosswise::value, + Crosswise>, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + public: + /// Fragment object to be loaded or stored + using Fragment = Array; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for k interleaved arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template +class RegularTileIterator< + Shape_, Element_, + layout::TensorOpMultiplicandRowMajorInterleaved::value, + InterleavedK>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = + layout::TensorOpMultiplicandRowMajorInterleaved::value, + InterleavedK>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename 
Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Internal details made public to facilitate introspection + struct Detail { + /// This iterator is specialized for an access size that is 128 bits in + /// length. + static int const kAccessSizeInBits = 128; + + static_assert(sizeof_bits::value * ThreadMap::kElementsPerAccess == + kAccessSizeInBits, + "This iterator requires a policy whose access size is 128bs"); + }; + + private: + + /// Element type per access + using AccessType = Array; + + public: + /// Fragment object to be loaded or stored + using Fragment = + Array; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = RegularTileAccessIterator; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : address_iterator_(ref, thread_id) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + address_iterator_.add_pointer_offset(Shape::kCount); + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + address_iterator_.add_pointer_offset(coord.contiguous() * Shape::kCount); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + address_iterator_.set_iteration_index(0); + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + frag_ptr[access_idx] = *(address_iterator_.get() + pointer_offset); + ++address_iterator_; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + *(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx]; + ++address_iterator_; + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for k interleaved arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// + +template +class 
RegularTileIterator< + Shape_, Element_, + layout::TensorOpMultiplicandColumnMajorInterleaved::value, + InterleavedK>, + AdvanceRank, ThreadMap_, Alignment> { + + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = + layout::TensorOpMultiplicandColumnMajorInterleaved::value, + InterleavedK>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + cutlass::MatrixShape, + Element, + layout::TensorOpMultiplicandRowMajorInterleaved::value, InterleavedK>, + (kAdvanceRank == 1 ? 0 : 1), + ThreadMap + >; + + public: + /// Fragment object to be loaded or stored + using Fragment = Array; + + private: + + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.strided(), coord.contiguous()}); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h new file mode 100644 index 0000000000000000000000000000000000000000..53121c6114cc3675e4d97f9da65d3ecb58e46d62 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op_sm70.h @@ -0,0 +1,1460 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses and visits the last "residue" tile + first, with the objective of minimizing predicate mask updates during steady-state operation. 
+ + A precomputed "Params" object minimizes the amount of state that must be stored in registers, + and integer addition is used to advance the pointer through memory. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor_op_multiplicand_sm70.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator< + Shape_, + Element_, + layout::VoltaTensorOpMultiplicandCongruous::value>, + AdvanceRank, + ThreadMap_, + Alignment> { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::VoltaTensorOpMultiplicandCongruous::value>; + static int const kAdvanceRank = AdvanceRank; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Internal details made public to facilitate introspection + struct Detail { + + /// This iterator is specialized for an 
access size that is 128 bits in length. + static int const kAccessSizeInBits = 128; + + static_assert( + sizeof_bits::value * ThreadMap::kElementsPerAccess == kAccessSizeInBits, + "This iterator requires a policy whose access size is 128bs"); + + ///< Number of pointers + static int const kPointerCount = (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); + }; + + +private: + + /// Element type per access + using AccessType = Array; + +public: + + /// Fragment object to be loaded or stored + using Fragment = Array; + +private: + + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType * pointer_[Detail::kPointerCount]; + + /// Internal byte offset + Index byte_offset_; + +public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): stride_(ref.stride(0) / Layout::kElementsPerAccess), byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kPointerCount; ++i) { + + // This is the offset of a thread within a threadblock tile for a specific pointer + // (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = + thread_offset_base + layout::PitchLinearCoord{0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; + + // initialize pointer + pointer_[i] = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + + add_pointer_offset((kAdvanceRank ? 
Shape::kStrided * stride_ * Layout::kElementsPerAccess : Shape::kContiguous)); + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + + RegularTileIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset( + coord.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + + coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess + ); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + AccessType *access_ptr = pointer_[s & 1]; + int stride_idx = (s & ~1); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + + c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + + vec_pointer_offset; + + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + + char const *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); + + frag_ptr[access_idx] = *reinterpret_cast(access_byte_ptr + byte_offset_); + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, + Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + AccessType *access_ptr = 
pointer_[s & 1]; + int stride_idx = (s & ~1); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + + c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + + vec_pointer_offset; + + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + + char *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); + + *reinterpret_cast(access_byte_ptr + byte_offset_) = frag_ptr[access_idx]; + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Tile Iterator specialized for column-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator< + Shape_, + Element_, + layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>, + AdvanceRank, + ThreadMap_, + Alignment> { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorVoltaTensorOpMultiplicandCongruous::value>; + static int const kAdvanceRank = AdvanceRank; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, + Element, + 
layout::VoltaTensorOpMultiplicandCongruous::value>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap_>; + +public: + + /// Fragment object to be loaded or stored + using Fragment = Array; + +private: + + /// Underlying iterator + UnderlyingIterator iterator_; + +public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): iterator_({ref.data(), ref.stride()}, thread_id) { + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, + Index pointer_offset) { + + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator< + Shape_, + Element_, + layout::RowMajorVoltaTensorOpMultiplicandCongruous::value>, + AdvanceRank, + ThreadMap_, + Alignment> { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorVoltaTensorOpMultiplicandCongruous::value>; + static int const kAdvanceRank = AdvanceRank; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using 
UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, + Element, + layout::VoltaTensorOpMultiplicandCongruous::value>, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap_>; + +public: + + /// Fragment object to be loaded or stored + using Fragment = Array; + +private: + + /// Underlying iterator + UnderlyingIterator iterator_; + +public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): iterator_({ref.data(), ref.stride()}, thread_id) { + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, + Index pointer_offset) { + + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; +/// Tile iterator specialized for congruous arrangements for TensorOps +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator< + Shape_, + Element_, + layout::VoltaTensorOpMultiplicandBCongruous::value>, + AdvanceRank, + ThreadMap_, + Alignment> { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::VoltaTensorOpMultiplicandBCongruous::value>; + static int const kAdvanceRank = AdvanceRank; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Internal details made public to facilitate introspection + struct Detail { + + /// This 
iterator is specialized for an access size that is 128 bits in length. + static int const kAccessSizeInBits = 128; + + static_assert( + sizeof_bits::value * ThreadMap::kElementsPerAccess == kAccessSizeInBits, + "This iterator requires a policy whose access size is 128bs"); + + ///< Number of pointers + static int const kPointerCount = (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); + }; + + +private: + + /// Element type per access + using AccessType = Array; + +public: + + /// Fragment object to be loaded or stored + using Fragment = Array; + +private: + + // + // Data members + // + + /// Stride value + StrideIndex stride_; + + /// Internal pointer to first access of tile + AccessType * pointer_[Detail::kPointerCount]; + + /// Internal byte offset + Index byte_offset_; + +public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): stride_(ref.stride(0) / Layout::kElementsPerAccess), byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = ThreadMap::initial_offset(thread_id); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kPointerCount; ++i) { + + // This is the offset of a thread within a threadblock tile for a specific pointer + // (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = + thread_offset_base + layout::PitchLinearCoord{0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; + + // initialize pointer + pointer_[i] = reinterpret_cast(ref.data() + ref.offset(thread_offset_in_threadblock_tile)); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + + add_pointer_offset((kAdvanceRank ? 
Shape::kStrided * stride_ * Layout::kElementsPerAccess : Shape::kContiguous)); + + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + + RegularTileIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset( + coord.contiguous() * Shape::kContiguous / ThreadMap::kElementsPerAccess + + coord.strided() * Shape::kStrided * stride_ * Layout::kElementsPerAccess + ); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + AccessType *access_ptr = pointer_[s & 1]; + int stride_idx = (s & ~1); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + + c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + + vec_pointer_offset; + + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + + char const *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); + + frag_ptr[access_idx] = *reinterpret_cast(access_byte_ptr + byte_offset_); + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, + Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + Index vec_pointer_offset = pointer_offset / ThreadMap::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + AccessType *access_ptr = 
pointer_[s & 1]; + int stride_idx = (s & ~1); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + int access_offset = stride_idx * ThreadMap::Delta::kStrided * stride_ + + c * ThreadMap::Delta::kContiguous / ThreadMap::kElementsPerAccess + + vec_pointer_offset; + + int access_idx = c + s * ThreadMap::Iterations::kContiguous; + + char *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); + + *reinterpret_cast(access_byte_ptr + byte_offset_) = frag_ptr[access_idx]; + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator< + Shape_, + Element_, + layout::ColumnMajorVoltaTensorOpMultiplicandBCongruous::value>, + AdvanceRank, + ThreadMap_, + Alignment> { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorVoltaTensorOpMultiplicandBCongruous::value>; + static int const kAdvanceRank = AdvanceRank; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, + Element, + 
layout::VoltaTensorOpMultiplicandBCongruous::value>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap_>; + +public: + + /// Fragment object to be loaded or stored + using Fragment = Array; + +private: + + /// Underlying iterator + UnderlyingIterator iterator_; + +public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): iterator_({ref.data(), ref.stride()}, thread_id) { + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, + Index pointer_offset) { + + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major congruous TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator< + Shape_, + Element_, + layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>, + AdvanceRank, + ThreadMap_, + Alignment> { +public: + + static_assert(AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorVoltaTensorOpMultiplicandBCongruous::value>; + static int const kAdvanceRank = AdvanceRank; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using 
UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, + Element, + layout::VoltaTensorOpMultiplicandBCongruous::value>, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap_>; + +public: + + /// Fragment object to be loaded or stored + using Fragment = Array; + +private: + + /// Underlying iterator + UnderlyingIterator iterator_; + +public: + + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator( + TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ): iterator_({ref.data(), ref.stride()}, thread_id) { + + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset( + Fragment const &frag, + Index pointer_offset) { + + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } +}; + + +/// Tile iterator specialized for crosswise arrangements for TensorOps. +/// +/// Volta TN SMEM layout is a little diffrent: +/// Crosseised elements will be stored in a line, while contiguous elements +/// sre stored in line-by-line. +/// Padding is used to reduce SMEM bank conflicts. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator< + Shape_, Element_, + layout::VoltaTensorOpMultiplicandCrosswise::value, + Shape_::kContiguous>, + AdvanceRank, ThreadMap_, Alignment> { + + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = + layout::VoltaTensorOpMultiplicandCrosswise::value, + Shape::kContiguous>; + static int const kAdvanceRank = AdvanceRank; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Internal details made public to facilitate introspection + struct Detail { + + ///< Number of pointers + static int const kPointerCount = (ThreadMap::Iterations::kStrided > 1 ? 2 : 1); + + /// Iterations for the kElementsPerAccess of ThreadMap + static int const kIterarionsPerAccess = + ThreadMap::kElementsPerAccess / Layout::kElementsPerAccess; + + /// Contiguous elements per line + static int const kContiguousElementsPerLine = 4; + }; + + private: + /// Element type per access + using AccessType = Array; + + public: + /// Fragment object to be loaded or stored + using Fragment = + Array; + + private: + // + // Data members + // + + /// The crosswised elements will be stored in a line. + /// line_size is size of crosswised dimension plus padding. 
+ /// in units of AccessType + Index line_size; + + /// Internal pointer to first access of tile + AccessType *pointer_[Detail::kPointerCount]; + + /// Internal byte offset + Index byte_offset_; + + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : line_size(ref.stride(0) * Detail::kContiguousElementsPerLine / Layout::kElementsPerAccess), + byte_offset_(0) { + + layout::PitchLinearCoord thread_offset_base = + ThreadMap::initial_offset(thread_id); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Detail::kPointerCount; ++i) { + // This is the offset of a thread within a threadblock tile for a specific + // pointer (units of elements) + layout::PitchLinearCoord thread_offset_in_threadblock_tile = + thread_offset_base + + layout::PitchLinearCoord{ + 0, ThreadMap::Detail::WarpThreadArrangement::kStrided * i}; + + // initialize pointer + pointer_[i] = reinterpret_cast( + ref.data() + ref.offset(thread_offset_in_threadblock_tile)); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_offset_ += pointer_offset * sizeof(Element); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + // (Shape::kContiguous/Layout::kElementsPerAccess)* + // line_size * Layout::kElementsPerAccess + add_pointer_offset(Shape::kContiguous * line_size); + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + add_pointer_offset((coord.contiguous() * (Shape::kContiguous / Layout::kElementsPerAccess) * + line_size + coord.strided() * Shape::kStrided) * + Layout::kElementsPerAccess); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + AccessType *frag_ptr = reinterpret_cast(&frag); + + Index vec_pointer_offset = pointer_offset / Layout::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + AccessType *access_ptr = pointer_[(s & 1) ^ (s / 2)]; + + access_ptr += 16 * (s / 2); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < Detail::kIterarionsPerAccess; ++i) { + + int access_offset = + c * ThreadMap::Delta::kContiguous / Detail::kContiguousElementsPerLine * line_size + + vec_pointer_offset + i * line_size; + + int access_idx = (c + s * ThreadMap::Iterations::kContiguous) * + Detail::kIterarionsPerAccess + i; + + char const *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); + + frag_ptr[access_idx] = *reinterpret_cast( + access_byte_ptr + byte_offset_); + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + AccessType const *frag_ptr = reinterpret_cast(&frag); + + Index vec_pointer_offset = pointer_offset / Layout::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + + AccessType *access_ptr = pointer_[(s & 1) ^ ((s >> 1) & 1)]; + + 
access_ptr += 16 * (s / 2) + vec_pointer_offset; + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < Detail::kIterarionsPerAccess; ++i) { + + int access_offset = + c * ThreadMap::Delta::kContiguous / Detail::kContiguousElementsPerLine * line_size + i * line_size; + + int access_idx = (c + s * ThreadMap::Iterations::kContiguous) * + Detail::kIterarionsPerAccess + i; + + char *access_byte_ptr = reinterpret_cast(access_ptr + access_offset); + + *reinterpret_cast(access_byte_ptr + byte_offset_) = + frag_ptr[access_idx]; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for column-major crosswise TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator::value, Shape_::kRow>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for column-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajorVoltaTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kRow>; + static int const kAdvanceRank = AdvanceRank; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using UnderlyingIterator = RegularTileIterator< + 
layout::PitchLinearShape, Element, + layout::VoltaTensorOpMultiplicandCrosswise::value, + Shape::kRow>, + (kAdvanceRank == 0 ? 0 : 1), ThreadMap_>; + + public: + /// Fragment object to be loaded or stored + using Fragment = Array; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.row(), coord.column()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile Iterator specialized for row-major crosswise TensorOp formats. +/// +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int Alignment +> +class RegularTileIterator::value, Shape_::kColumn>, + AdvanceRank, ThreadMap_, Alignment> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for row-major iterator may along advance along the " + "columns(rank=0) or rows(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajorVoltaTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kColumn>; + static int const kAdvanceRank = AdvanceRank; + static int const kAlignment = Alignment; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using ThreadMap = ThreadMap_; + + /// Underlying iterator type + using 
UnderlyingIterator = RegularTileIterator< + layout::PitchLinearShape, Element, + layout::VoltaTensorOpMultiplicandCrosswise::value, + Shape::kColumn>, + (kAdvanceRank == 0 ? 1 : 0), ThreadMap_>; + + public: + /// Fragment object to be loaded or stored + using Fragment = Array; + + private: + /// Underlying iterator + UnderlyingIterator iterator_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + RegularTileIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : iterator_({ref.data(), ref.stride()}, thread_id) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Adds a tile offset + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + iterator_.add_tile_offset({coord.column(), coord.row()}); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + RegularTileIterator &operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + RegularTileIterator operator++(int) { + RegularTileIterator prev(*this); + ++iterator_; + + return prev; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { load_with_pointer_offset(frag, 0); } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/vector_iterator.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/vector_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..8e5d181c177b2ad6627c927ae4ad3fb9c99a96d3 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/threadblock/vector_iterator.h @@ -0,0 +1,149 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template wraps the vector access iterator concept to load whole vector from tensors in + memory. This is typically used for per-channel scale and bias in convolution kernels. 
+*/ + +#pragma once + +#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class VectorIterator { +public: + using VectorAccessIterator = VectorAccessIterator_; + + using Shape = typename VectorAccessIterator::Shape; + using Element = typename VectorAccessIterator::Element; + using Layout = typename VectorAccessIterator::Layout; + using TensorCoord = typename Layout::TensorCoord; + using AccessType = typename VectorAccessIterator::AccessType; + using TensorRef = typename VectorAccessIterator::TensorRef; + using Index = typename VectorAccessIterator::Index; + using LongIndex = typename VectorAccessIterator::LongIndex; + + static int const kElementsPerAccess = VectorAccessIterator::kElementsPerAccess; + static int const kRowsPerIteration = VectorAccessIterator::kRowsPerIteration; + static int const kThreads = VectorAccessIterator::kThreads; + static int const kIterations = VectorAccessIterator::kIterations; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, kElementsPerAccess * kIterations>; + +private: + + /// Internal state + VectorAccessIterator vector_access_iterator_; + +public: + + /// Constructor + CUTLASS_HOST_DEVICE + VectorIterator( + Element const *ptr, + TensorCoord extent, + int thread_idx, + int warp_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + vector_access_iterator_(ptr, extent, thread_idx, warp_idx, threadblock_offset) { } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + VectorIterator &operator++() { + vector_access_iterator_.advance(); + return *this; + } + + /// Advances to the next tile in memory. 
+ CUTLASS_HOST_DEVICE + VectorIterator operator++(int) { + VectorIterator self(*this); + operator++(); + return self; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) { + + frag.clear(); + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < kIterations; ++c) { + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[c], + vector_access_iterator_.get() + pointer_offset, + vector_access_iterator_.valid() + ); + + ++vector_access_iterator_; + } +// } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) { + vector_access_iterator_.set_iteration_index(0); + load_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void advance() { + vector_access_iterator_.advance(); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..b27b77f9b697476ed54a019cd94120561371ebd1 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/transform/warp/vector_fragment_iterator.h @@ -0,0 +1,283 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +/*! \file + \brief This defines a "fragment" iterator for visiting the fragments of a warp vector + that participate in one warp-level mma operation. + + Typically, this is used to access the scale/bias fragment of a warp-level mma operation. 
+ The scale/bias vector is then partitioned into smaller fragments that can be fed into + next warp-level mma operation. + + This iterator is necessary to accomplish warp-level mma fusion where the scale/bias vector is + applied to the multiplicand for the next mma. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_conversion.h" + +namespace cutlass { +namespace transform { +namespace warp { + + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Size of the input fragment tile shape (concept: MatrixShape) + typename Shape_, + /// Element type + typename Element_, + /// Layout of operand in memory + typename Layout_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + //// Number of elements per access when loading fragment + int ElementsPerAccess> +class VectorFragmentIterator; + + +// Partial specialization for PitchLinear layout tile + +template < + /// Size of the input fragment vector shape (concept: MatrixShape) + typename Shape_, + /// Element type + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + //// Number of elements per access when loading fragment + int ElementsPerAccess> +class VectorFragmentIterator { + public: + + /// Size of the input threadblock tile shape (concept: MatrixShape) + using Shape = Shape_; + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::PitchLinear; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Number of participating threads + static int const kThreads = 32; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kRowsPerIteration = 8; + 
static int const kColumnsPerAccess = 8; + static int const kElementsPerIteration = kRowsPerIteration * InstructionShape::kK / kThreads; + static int const kAccessPerIteration = kElementsPerIteration / kElementsPerAccess; + + /// Number of iterations + using Iterations = MatrixShape; + +public: + + // + // Derived quantities + // + // All fragments have kElementsPerAccess scale followed by bias + + /// Fragment object holding a thread's part of a tile + /// This is the fragment size produced by one iteration of the iterator. + using Fragment = Array; + + /// Input threadblock fragment tile + using ThreadblockFragment = Array; + +private: + + /// Internal access type + using AccessType = Array; + +private: + // + // Data members + // + + /// Input threadblock fragment tile + AccessType const *iterator_; + + /// Internal index + int index_; + +public: + /// Constructs an iterator + CUTLASS_HOST_DEVICE + VectorFragmentIterator(ThreadblockFragment const &threadblock_frag) + : iterator_(reinterpret_cast(&threadblock_frag)), + index_(0) {} + + /// Add offset + CUTLASS_HOST_DEVICE + void add_offset(int index_offset) { + index_ += index_offset; + + if(index_ >= Iterations::kColumn) + index_ = 0; + } + + /// Increments + CUTLASS_HOST_DEVICE + VectorFragmentIterator &operator++() { + add_offset(1); + return *this; + } + + CUTLASS_HOST_DEVICE + void set_index(int idx) { + index_ = idx; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int r = 0; r < Iterations::kRow; r++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kAccessPerIteration; i++) { + + frag_ptr[i * Iterations::kRow + r].clear(); + frag_ptr[i * Iterations::kRow + r] = iterator_[index_ * kAccessPerIteration + i]; + } + } + } + +}; + +// Partial specialization for Row-Major layout tile + +template < + /// Size of the input fragment tile shape 
(concept: MatrixShape) + typename Shape_, + /// Element type + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + //// Number of elements per access when loading fragment + int ElementsPerAccess> +class VectorFragmentIterator { + public: + + /// Size of the input threadblock tile shape (concept: MatrixShape) + using Shape = Shape_; + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Underlying iterator + using Base = VectorFragmentIterator< + layout::PitchLinearShape, Element, + layout::PitchLinear, InstructionShape, ElementsPerAccess>; + + + public: + + // + // Derived quantities + // + /// Fragment object holding a thread's part of a tile + /// This is the fragment size produced by one iteration of the iterator. + using Fragment = typename Base::Fragment; + + /// Input threadblock fragment tile + using ThreadblockFragment = typename Base::ThreadblockFragment; + + private: + /// Underlying iterator + Base iterator_; + +public: + /// Constructs an iterator + CUTLASS_HOST_DEVICE + VectorFragmentIterator(ThreadblockFragment const &threadblock_frag) + : iterator_(threadblock_frag) {} + + /// Add offset + CUTLASS_HOST_DEVICE + void add_offset(int index_offset) { + iterator_.add_offset(index_offset); + } + + /// Increments + CUTLASS_HOST_DEVICE + VectorFragmentIterator &operator++() { + add_offset(1); + return *this; + } + + CUTLASS_HOST_DEVICE + void set_index(int idx) { + iterator_.set_index(idx); + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + iterator_.load(frag); + } + +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace conv +} // namespace 
cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/uint128.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/uint128.h new file mode 100644 index 0000000000000000000000000000000000000000..68896d6b60767221fd41421a0d3fdf75392c3604 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/uint128.h @@ -0,0 +1,269 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Defines an unsigned 128b integer with several operators to support 64-bit integer division. +*/ +#pragma once +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(cstdint) +#else +#include +#include +#include +#include +#include +#endif + + +/// Optionally enable GCC's built-in type +#if (defined(__x86_64) || defined (__aarch64__)) && !(defined(__CUDA_ARCH__) && ((__CUDACC_VER_MAJOR__ <= 10) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ <= 4)))) && defined(__GNUC__) +#define CUTLASS_UINT128_NATIVE +#elif !defined(__CUDA_ARCH__) +// No custom support for 128b arithmetic on device +#if defined(_MSC_VER) && defined(_M_AMD64) +#define CUTLASS_INT128_ARITHMETIC +#include +#if _MSC_VER >= 1920 && !defined(__CUDA_ARCH__) +#define CUTLASS_INT128_ARITHMETIC_DIV +#include +#endif +#endif +#endif + +namespace cutlass { + +///! Unsigned 128b integer type +struct alignas(16) uint128_t +{ + /// Size of one part of the uint's storage in bits + static constexpr int storage_bits_ = 64; + + struct hilo + { + uint64_t lo; + uint64_t hi; + }; + + // Use a union to store either low and high parts or, if present, a built-in 128b integer type. 
+ union { + struct hilo hilo_; + +#if defined(CUTLASS_UINT128_NATIVE) + unsigned __int128 native; +#endif // defined(CUTLASS_UINT128_NATIVE) + }; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + uint128_t() : hilo_{0, 0} {} + + /// Constructor from uint64 + CUTLASS_HOST_DEVICE + uint128_t(uint64_t lo_) : hilo_{lo_, 0} {} + + /// Constructor from two 64b unsigned integers + CUTLASS_HOST_DEVICE + uint128_t(uint64_t lo_, uint64_t hi_) : hilo_{lo_, hi_} {} + + /// Optional constructor from native value +#if defined(CUTLASS_UINT128_NATIVE) + uint128_t(unsigned __int128 value) : native(value) { } +#endif + + /// Lossily cast to uint64 + CUTLASS_HOST_DEVICE + explicit operator uint64_t() const + { + return hilo_.lo; + } + + CUTLASS_HOST_DEVICE + static void exception() + { +#if defined(__CUDA_ARCH__) + asm volatile (" brkpt;\n"); +#else + // throw std::runtime_error("Not yet implemented."); + abort(); +#endif + } + + /// Add + CUTLASS_HOST_DEVICE + uint128_t operator+(uint128_t const& rhs) const + { + uint128_t y{}; +#if defined(CUTLASS_UINT128_NATIVE) + y.native = native + rhs.native; +#else + y.hilo_.lo = hilo_.lo + rhs.hilo_.lo; + y.hilo_.hi = hilo_.hi + rhs.hilo_.hi + (y.hilo_.lo < hilo_.lo); +#endif + return y; + } + + /// Subtract + CUTLASS_HOST_DEVICE + uint128_t operator-(uint128_t const& rhs) const + { + uint128_t y{}; +#if defined(CUTLASS_UINT128_NATIVE) + y.native = native - rhs.native; +#else + y.hilo_.lo = hilo_.lo - rhs.hilo_.lo; + y.hilo_.hi = hilo_.hi - rhs.hilo_.hi - (rhs.hilo_.lo && y.hilo_.lo > hilo_.lo); +#endif + return y; + } + + /// Multiply by unsigned 64b integer yielding 128b integer + CUTLASS_HOST_DEVICE + uint128_t operator*(uint64_t const& rhs) const + { + uint128_t y{}; +#if defined(CUTLASS_UINT128_NATIVE) + y.native = native * rhs; +#elif defined(CUTLASS_INT128_ARITHMETIC) + // Multiply by the low part + y.hilo_.lo = _umul128(hilo_.lo, rhs, &y.hilo_.hi); + + // Add the high part and ignore the overflow + uint64_t 
overflow{0}; + y.hilo_.hi += _umul128(hilo_.hi, rhs, &overflow); +#else + CUTLASS_UNUSED(rhs); + exception(); +#endif + return y; + } + + /// Divide 128b operation by 64b operation yielding a 64b quotient + CUTLASS_HOST_DEVICE + uint64_t operator/(uint64_t const& divisor) const + { + uint64_t quotient{0}; +#if defined(CUTLASS_UINT128_NATIVE) + quotient = uint64_t(native / divisor); +#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) + // implemented using MSVC's arithmetic intrinsics + uint64_t remainder{0}; + quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); +#else + CUTLASS_UNUSED(divisor); + exception(); +#endif + return quotient; + } + + /// Divide 128b operation by 64b operation yielding a 64b quotient + CUTLASS_HOST_DEVICE + uint64_t operator%(uint64_t const& divisor) const + { + uint64_t remainder{0}; +#if defined(CUTLASS_UINT128_NATIVE) + remainder = uint64_t(native % divisor); +#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) + // implemented using MSVC's arithmetic intrinsics + (void)_udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); +#else + CUTLASS_UNUSED(divisor); + exception(); +#endif + return remainder; + } + + /// Computes the quotient and remainder in a single method. 
+ CUTLASS_HOST_DEVICE + uint64_t divmod(uint64_t &remainder, uint64_t divisor) const + { + uint64_t quotient{0}; +#if defined(CUTLASS_UINT128_NATIVE) + quotient = uint64_t(native / divisor); + remainder = uint64_t(native % divisor); +#elif defined(CUTLASS_INT128_ARITHMETIC_DIV) + // implemented using MSVC's arithmetic intrinsics + quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); +#else + CUTLASS_UNUSED(remainder); + CUTLASS_UNUSED(divisor); + exception(); +#endif + return quotient; + } + + /// Left-shifts a 128b unsigned integer + CUTLASS_HOST_DEVICE + uint128_t operator<<(int sh) const + { + if (sh == 0) { + return *this; + } + else if (sh >= storage_bits_) { + return uint128_t(0, hilo_.lo << (sh - storage_bits_)); + } + else { + return uint128_t( + (hilo_.lo << sh), + (hilo_.hi << sh) | uint64_t(hilo_.lo >> (storage_bits_ - sh)) + ); + } + } + + /// Right-shifts a 128b unsigned integer + CUTLASS_HOST_DEVICE + uint128_t operator>>(int sh) const + { + if (sh == 0) { + return *this; + } + else if (sh >= storage_bits_) { + return uint128_t((hilo_.hi >> (sh - storage_bits_)), 0); + } + else { + return uint128_t( + (hilo_.lo >> sh) | (hilo_.hi << (storage_bits_ - sh)), + (hilo_.hi >> sh) + ); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/uint256.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/uint256.h new file mode 100644 index 0000000000000000000000000000000000000000..3657853557ebccfd6be63ce6ba0fa4d69880d649 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/uint256.h @@ -0,0 +1,93 @@ 
+/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Defines an unsigned 256b integer. 
+*/ + +#pragma once +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(cstdint) +#else +#include +#include +#include +#include +#include +#endif +#include "cutlass/uint128.h" + +namespace cutlass { + +///! Unsigned 256b integer type +struct alignas(32) uint256_t { + /// Size of one part of the uint's storage in bits + static constexpr int storage_bits_ = 128; + + struct hilo { + uint128_t lo; + uint128_t hi; + }; + + // Use a union to store either low and high parts. + union { + struct hilo hilo_; + }; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + uint256_t() : hilo_{uint128_t{}, uint128_t{}} {} + + /// Constructor from uint128 + CUTLASS_HOST_DEVICE + uint256_t(uint128_t lo_) : hilo_{lo_, uint128_t{}} {} + + /// Constructor from two 128b unsigned integers + CUTLASS_HOST_DEVICE + uint256_t(uint128_t lo_, uint128_t hi_) : hilo_{lo_, hi_} {} + + /// Lossily cast to uint128_t + CUTLASS_HOST_DEVICE + explicit operator uint128_t() const { + return hilo_.lo; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/version.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/version.h new file mode 100644 index 0000000000000000000000000000000000000000..57a73a5fbb41a22ed5e44743c84fa1bbbe0b0075 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/version.h @@ -0,0 +1,80 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include + +#define CUTLASS_MAJOR 4 +#define CUTLASS_MINOR 2 +#define CUTLASS_PATCH 1 + +#ifdef CUTLASS_VERSIONS_GENERATED +#include "cutlass/version_extended.h" +#else +#define CUTLASS_BUILD 0 +#define CUTLASS_REVISION "" +#endif + +#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH) + +namespace cutlass { + + inline constexpr uint32_t getVersion() { + return CUTLASS_VERSION; + } + inline constexpr uint32_t getVersionMajor() { + return CUTLASS_MAJOR; + } + inline constexpr uint32_t getVersionMinor() { + return CUTLASS_MINOR; + } + inline constexpr uint32_t getVersionPatch() { + return CUTLASS_PATCH; + } + inline constexpr uint32_t getVersionBuild() { + return CUTLASS_BUILD + 0; + } + + inline std::string getVersionString() { + std::string version = "@CUTLASS_VERSION@"; + if (getVersionBuild()) { + version += "." + std::to_string(getVersionBuild()); + } + return version; + } + + inline std::string getGitRevision() { + return "@CUTLASS_REVISION@"; + } + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/wmma_array.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/wmma_array.h new file mode 100644 index 0000000000000000000000000000000000000000..77929f60f73dc07ea2a8e47de1cfb95b5f8859f0 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/wmma_array.h @@ -0,0 +1,133 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types + and is safe to use in a union. 
+*/ + +#pragma once + +#include "cutlass/arch/wmma.h" + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Wmma array type (WmmaFragmentArray holds elements of type nvcuda::wmma::fragment) +template < + /// Element type + typename T, + /// Number of elements in the array + int N, + /// Whether the element type of T is half_t or __half + bool IsHalfType = (platform::is_same::value || + platform::is_same::value) +> +class WmmaFragmentArray: public Array { +public: + + /// Efficient clear method (override Array::clear()) + CUTLASS_HOST_DEVICE + void clear() + { + for(int i = 0; i < Array::kElements; i++) + { + nvcuda::wmma::fill_fragment((*this)[i], (typename T::element_type)0); + } + } + + CUTLASS_HOST_DEVICE + WmmaFragmentArray& operator+=(const WmmaFragmentArray& rhs) + { + using element_type = typename T::element_type; + plus add; + + for (int i = 0; i < Array::kElements; i++) + { + (*this)[i] = add((*this)[i], rhs[i]); + } + + return *this; + } +}; + +/// Partial specialization for the case in which T::element_type is +/// half_t or __half. This is needed because the cast (typename T::element_type)0 +/// in the primary template flags as an error when __CUDA_NO_HALF_CONVERSIONS__ +/// is set. 
+template < + /// Element type + typename T, + /// Number of elements in the array + int N +> +class WmmaFragmentArray: public Array { +public: + + /// Efficient clear method (override Array::clear()) + CUTLASS_HOST_DEVICE + void clear() + { + for(int i = 0; i < Array::kElements; i++) + { + nvcuda::wmma::fill_fragment((*this)[i], __float2half(0.f)); + } + } + + CUTLASS_HOST_DEVICE + WmmaFragmentArray& operator+=(const WmmaFragmentArray& rhs) + { + using element_type = typename T::element_type; + plus add; + + for (int i = 0; i < Array::kElements; i++) + { + (*this)[i] = add((*this)[i], rhs[i]); + } + + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/workspace.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/workspace.h new file mode 100644 index 0000000000000000000000000000000000000000..485ebbe3ae27af7ddc05bc1e36f32b1a4ee65901 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/include/cutlass/workspace.h @@ -0,0 +1,154 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Utilities for initializing workspaces +*/ + +#pragma once + +#if !defined(__CUDACC_RTC__) +#include "cuda.h" +#include "cuda_runtime.h" + +#include "cutlass/trace.h" +#endif + +#include "cutlass.h" +#include "cutlass/cuda_host_adapter.hpp" + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +static constexpr int MinWorkspaceAlignment = 16; + +#if !defined(__CUDACC_RTC__) +static Status +zero_workspace( + void* workspace, + size_t workspace_size, + cudaStream_t stream = nullptr, + [[maybe_unused]] CudaHostAdapter *cuda_adapter = nullptr) { + if (workspace_size > 0) { + if (workspace == nullptr) { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + return Status::kErrorWorkspaceNull; + } + + CUTLASS_TRACE_HOST(" clearing workspace"); + +#if defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) && CUTLASS_ENABLE_CUDA_HOST_ADAPTER + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + if (Status::kSuccess != cuda_adapter->memsetDevice(workspace, static_cast(0), workspace_size, stream)) { + return Status::kErrorInternal; + } + } + else { + return Status::kErrorInternal; + } +#else + cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_size, stream); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } +#endif + } + + return Status::kSuccess; +} +#endif + +#if !defined(__CUDACC_RTC__) +template +Status +fill_workspace(void* workspace, T fill_value, size_t fill_count, cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr) { + static_assert(sizeof(T) == 4 || sizeof(T) == 2 || sizeof(T) == 1, "Unsupported fill type"); + if (fill_count > 0) { + if (workspace == nullptr) { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + return 
Status::kErrorWorkspaceNull; + } + + CUTLASS_TRACE_HOST(" filling workspace"); + +#if defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) && CUTLASS_ENABLE_CUDA_HOST_ADAPTER + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + if (Status::kSuccess != cuda_adapter->memsetDevice(workspace, fill_value, fill_count, stream)) { + return Status::kErrorInternal; + } + } + else { + return Status::kErrorInternal; + } +#else + CUdeviceptr d_workspace = reinterpret_cast(workspace); + CUresult result = CUDA_SUCCESS; + if (sizeof(T) == 4) { + result = cuMemsetD32Async(d_workspace, reinterpret_cast(fill_value), fill_count, stream); + } + else if (sizeof(T) == 2) { + result = cuMemsetD16Async(d_workspace, reinterpret_cast(fill_value), fill_count, stream); + } + else if (sizeof(T) == 1) { + result = cuMemsetD8Async(d_workspace, reinterpret_cast(fill_value), fill_count, stream); + } + + if (CUDA_SUCCESS != result) { + const char** error_string_ptr = nullptr; + (void) cuGetErrorString(result, error_string_ptr); + if (error_string_ptr != nullptr) { + CUTLASS_TRACE_HOST(" cuMemsetD" << sizeof(T) * 8 << "Async() returned error " << *error_string_ptr); + } + else { + CUTLASS_TRACE_HOST(" cuMemsetD" << sizeof(T) * 8 << "Async() returned unrecognized error"); + } + return Status::kErrorInternal; + } +#endif + } + + return Status::kSuccess; +} +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb617dc20d35f6dd352a84c3964a58fa9bc687e --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/__init__.py @@ -0,0 +1,17 @@ +# 
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +# Local module imports +from .dsl import * +from .runtime import * +from ._mlir_helpers import lru_cache_ir +from .env_manager import get_str_env_var, detect_gpu_arch + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..607a24d032c6ef899b586a41d2bb771c381406b0 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/__init__.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides MLIR Dialect helper functions +""" + +from . import arith +from .lru_cache_ir import lru_cache_ir + + +__all__ = ["arith", "lru_cache_ir"] + +try: + from . 
import gpu + + __all__.extend(["gpu"]) +except ImportError: + pass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/arith.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/arith.py new file mode 100644 index 0000000000000000000000000000000000000000..60cc8db31fd7369d721f3d7c64c5bb8fb03502a8 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/arith.py @@ -0,0 +1,691 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ +""" +This module provides MLIR Arith Dialect helper functions +""" + +import array +import numpy as np + +from ..common import * +from ..._mlir import ir # type: ignore +from ..._mlir.extras import types as T # type: ignore +from ..._mlir.dialects import arith, nvgpu, math, builtin # type: ignore + +from .lru_cache_ir import lru_cache_ir + +# ============================================================================= +# Arith Dialect Helper functions +# ============================================================================= + + +def recast_type(src_type, res_elem_type) -> ir.Type: + if isinstance(src_type, T.VectorType): + if src_type.scalable: + res_type = T.vector( + *src_type.shape, + res_elem_type, + scalable=src_type.scalable, + scalable_dims=src_type.scalable_dims, + ) + else: + res_type = T.vector(*src_type.shape, res_elem_type) + elif isinstance(src_type, T.RankedTensorType): + res_type = T.RankedTensorType.get( + element_type=res_elem_type, shape=src_type.shape, strides=src_type.strides + ) + elif isinstance(src_type, T.UnrankedTensorType): + res_type = T.UnrankedTensorType.get(element_type=res_elem_type) + elif isinstance(src_type, T.MemRefType): + res_type = T.MemRefType.get( + element_type=res_elem_type, shape=src_type.shape, strides=src_type.strides + ) + else: + res_type = res_elem_type + return res_type + + +def is_scalar(ty) -> bool: + return not isinstance( + ty, (T.VectorType, T.RankedTensorType, T.UnrankedTensorType, T.MemRefType) + ) + + +def element_type(ty) -> ir.Type: + if not is_scalar(ty): + return ty.element_type + else: + return ty + + +def is_narrow_precision(ty) -> bool: + narrow_types = { + T.f8E8M0FNU(), + T.f8E4M3FN(), + T.f8E4M3(), + T.f8E5M2(), + T.f8E4M3B11FNUZ(), + T.f4E2M1FN(), + T.f6E3M2FN(), + T.f6E2M3FN(), + } + return ty in narrow_types + + +def is_float_type(ty) -> bool: + return ( + arith._is_float_type(ty) + # TODO-upstream: prediction is not correct. 
Patch here and fix in upstream later + or is_narrow_precision(ty) + or ty in (T.bf16(), T.tf32()) + ) + + +def truncf_to_narrow(res_ty, src, loc, ip): + res_elem_ty = element_type(res_ty) + if res_elem_ty == T.f8E8M0FNU(): + rnd = nvgpu.RoundingMode.RP + else: + rnd = nvgpu.RoundingMode.RN + return nvgpu.cvt_fptrunc(res_ty, src, rnd=rnd, loc=loc, ip=ip) + + +def extf_from_narrow(res_ty, src, loc, ip): + src_elem_ty = element_type(src.type) + + # When source type is E8M0, temporary element type has to be bf16 + tmp_elem_ty = T.bf16() if src_elem_ty == T.f8E8M0FNU() else T.f16() + tmp_ty = recast_type(src.type, tmp_elem_ty) + + # narrow -> bf16/f16 -> target type + tmp = nvgpu.cvt_fpext(tmp_ty, src, loc=loc, ip=ip) + return arith.extf(res_ty, tmp, loc=loc, ip=ip) + + +def bitcast(src, res_elem_type, *, loc=None, ip=None): + res_type = recast_type(src.type, res_elem_type) + return arith.bitcast(res_type, src, loc=loc, ip=ip) + + +def cvtf(src, res_elem_type, *, loc=None, ip=None): + src_elem_type = element_type(src.type) + + if res_elem_type == src_elem_type: + return src + + res_type = recast_type(src.type, res_elem_type) + + # Treat TF32 as F32 and use i32 as intermediate data + # TODO-upstream: update arith to support tf32 <-> f32 conversion + if src_elem_type == T.tf32(): + # tf32 -> i32 + tmp_type = recast_type(src.type, T.i32()) + src = builtin.unrealized_conversion_cast([tmp_type], [src], loc=loc, ip=ip) + # i32 -> f32 + src = bitcast(src, T.f32(), loc=loc, ip=ip) + # f32 -> X with `cvtf` recursively + return cvtf(src, res_elem_type, loc=loc, ip=ip) + + if res_elem_type == T.tf32(): + # X -> f32 with `cvtf`` recursively + tmp = cvtf(src, T.f32(), loc=loc, ip=ip) + # f32 -> i32 + tmp = bitcast(tmp, T.i32(), loc=loc, ip=ip) + # i32 -> tf32 + return builtin.unrealized_conversion_cast([res_type], [tmp], loc=loc, ip=ip) + + if res_elem_type.width > src_elem_type.width: + if is_narrow_precision(src_elem_type): + return extf_from_narrow(res_type, src, loc, ip) + else: 
+ return arith.extf(res_type, src, loc=loc, ip=ip) + else: + tmp_mlir_type = recast_type(src.type, T.f32()) + + # f16 -- extf -> f32 -- truncf -> bf16 + # TODO-upstream: update arith to support bf16 <-> f16 conversion? + if (src_elem_type == T.f16() and res_elem_type == T.bf16()) or ( + src_elem_type == T.bf16() and res_elem_type == T.f16() + ): + tmp = arith.extf(tmp_mlir_type, src, loc=loc, ip=ip) + return arith.truncf(res_type, tmp, loc=loc, ip=ip) + + # {f8, f6, f4} -> f16, f32, ... + elif is_narrow_precision(res_elem_type): + return truncf_to_narrow(res_type, src, loc, ip) + else: + return arith.truncf(res_type, src, loc=loc, ip=ip) + + +def fptoi(src, signed: Union[bool, None], res_elem_type, *, loc=None, ip=None): + res_type = recast_type(src.type, res_elem_type) + # TODO-upstream: update arith to support this kind of conversion + if element_type(src.type) in (T.tf32(), T.bf16()): + src = cvtf(src, T.f32(), loc=loc, ip=ip) + + if signed: + return arith.fptosi(res_type, src, loc=loc, ip=ip) + else: + return arith.fptoui(res_type, src, loc=loc, ip=ip) + + +def itofp(src, signed: Union[bool, None], res_elem_type, *, loc=None, ip=None): + res_type = recast_type(src.type, res_elem_type) + + orig_res_type = res_type + # TODO-upstream: update arith to support this kind of conversion + if res_elem_type in (T.tf32(), T.bf16()): + res_type = recast_type(src.type, T.f32()) + + if signed and element_type(src.type).width > 1: + res = arith.sitofp(res_type, src, loc=loc, ip=ip) + else: + res = arith.uitofp(res_type, src, loc=loc, ip=ip) + + if orig_res_type == res_type: + return res + + return cvtf(res, element_type(orig_res_type), loc=loc, ip=ip) + + +def int_to_int(a, dst_elem_type, *, loc=None, ip=None): + src_signed = a.signed + dst_signed = dst_elem_type.signed + src_width = element_type(a.type).width + dst_width = dst_elem_type.width + + dst_mlir_type = recast_type(a.type, dst_elem_type.mlir_type) + + if dst_width == src_width: + return a + elif src_signed != False 
and not dst_signed: + # Signed -> Unsigned + if dst_width > src_width: + return arith.extui(dst_mlir_type, a, loc=loc, ip=ip) + else: + return arith.trunci(dst_mlir_type, a, loc=loc, ip=ip) + elif src_signed == dst_signed: + # Same signedness + if dst_width > src_width: + if src_signed != False and src_width > 1: + return arith.extsi(dst_mlir_type, a, loc=loc, ip=ip) + else: + return arith.extui(dst_mlir_type, a, loc=loc, ip=ip) + else: + return arith.trunci(dst_mlir_type, a, loc=loc, ip=ip) + else: + # Unsigned -> Signed + if dst_width > src_width: + return arith.extui(dst_mlir_type, a, loc=loc, ip=ip) + else: + # For truncation from unsigned to signed, we need to handle overflow + # First truncate to the target width + trunc = arith.trunci(dst_mlir_type, a, loc=loc, ip=ip) + # Then reinterpret as signed + if dst_signed: + return arith.bitcast(dst_mlir_type, trunc, loc=loc, ip=ip) + return trunc + + +# ============================================================================= +# Arith Ops Emitter Helpers +# - assuming type of lhs and rhs match each other +# - op name matches python module operator +# ============================================================================= + + +def _cast(res_elem_ty, src, is_signed=None, *, loc=None, ip=None): + """ + This function provides simplified interface to upstream op builder + arith.truncf(T.vector(shape, new_type), src) + + is simplified as because it's element-wise op which can't change shape + arith.truncf(new_type, src) + """ + if isinstance(src, ir.Value): + src_ty = src.type + else: + src_ty = type(src).mlir_type + src = src.ir_value() + + src_elem_ty = element_type(src_ty) + + if src_elem_ty == res_elem_ty: + return src + elif is_float_type(src_elem_ty) and is_float_type(res_elem_ty): + # float-to-float + return cvtf(src, res_elem_ty, loc=loc, ip=ip) + elif arith._is_integer_like_type(src_elem_ty) and arith._is_integer_like_type( + res_elem_ty + ): + if src_elem_ty.width >= res_elem_ty.width: + cast_op = 
arith.trunci + else: + if is_signed: + cast_op = arith.extsi + else: + cast_op = arith.extui + + res_ty = recast_type(src_ty, res_elem_ty) + return cast_op(res_ty, src, loc=loc, ip=ip) + elif is_float_type(src_elem_ty) and arith._is_integer_like_type(res_elem_ty): + return fptoi(src, is_signed, res_elem_ty, loc=loc, ip=ip) + elif arith._is_integer_like_type(src_elem_ty) and is_float_type(res_elem_ty): + return itofp(src, is_signed, res_elem_ty, loc=loc, ip=ip) + else: + raise DSLRuntimeError( + f"cast from {src_elem_ty} to {res_elem_ty} is not supported" + ) + + +@lru_cache_ir() +def const(value, ty=None, *, loc=None, ip=None): + """ + Generates dynamic expression for constant values. + """ + from ..typing import Numeric, NumericMeta + from ..dsl import is_dynamic_expression, _numpy_type_to_mlir_type + + if isinstance(value, Numeric): + value = value.value + + # Early return + if is_dynamic_expression(value) and ( + value.type.isinstance(value.type) or T.bool().isinstance(value.type) + ): + return value + + # Assume type + if ty is None: + if isinstance(value, float): + ty = T.f32() + elif isinstance(value, bool): + ty = T.bool() + elif isinstance(value, int): + ty = T.i32() + elif isinstance(value, np.ndarray): + ty = T.vector(*value.shape, _numpy_type_to_mlir_type(value.dtype)) + value = array.array(value.dtype.kind, value.flatten().tolist()) + else: + raise DSLNotImplemented(f"{type(value)} is not supported") + elif isinstance(ty, NumericMeta): + ty = ty.mlir_type + elif isinstance(ty, ir.Type): + if ir.RankedTensorType.isinstance(ty) or ir.VectorType.isinstance(ty): + elem_ty = ty.element_type + if isinstance(elem_ty, ir.IntegerType): + attr = ir.IntegerAttr.get(elem_ty, value) + else: + attr = ir.FloatAttr.get(elem_ty, value) + value = ir.DenseElementsAttr.get_splat(ty, attr) + elif arith._is_float_type(ty) and isinstance(value, (bool, int)): + value = float(value) + elif arith._is_integer_like_type(ty) and isinstance(value, float): + value = int(value) + 
else: + raise DSLNotImplemented(f"type {ty} is not supported") + + return arith.constant(ty, value, loc=loc, ip=ip) + + +def _dispatch_to_rhs_r_op(op): + """Decorator that dispatches to the right-hand-side's reverse operation. + + If the other operand is not an ArithValue or is a subclass (more specific) + of ArithValue, this allows proper method resolution for binary operations. + """ + + def wrapper(self, other, **kwargs): + if not isinstance(other, ArithValue): + if not isinstance(other, (int, float, bool)): + # allows to call other.__rmul__ + return NotImplemented + + return op(self, other, **kwargs) + + return wrapper + + +def _binary_op(op): + """ + Decorator to check if the 'other' argument is an ArithValue. + If not, returns NotImplemented. + """ + + def wrapper(self, other, **kwargs): + # When reach this point, `self` must be cast to base `ArithValue` type + if isinstance(other, (int, float, bool)): + other = const(other, self.type).with_signedness(self.signed) + + # Call the original function + # If sub-class doesn't implement overloaded arithmetic, cast to base class + return op(self, other, **kwargs) + + return wrapper + + +# Operator overloading +@ir.register_value_caster(ir.Float4E2M1FNType.static_typeid) +@ir.register_value_caster(ir.Float6E2M3FNType.static_typeid) +@ir.register_value_caster(ir.Float6E3M2FNType.static_typeid) +@ir.register_value_caster(ir.Float8E4M3FNType.static_typeid) +@ir.register_value_caster(ir.Float8E4M3B11FNUZType.static_typeid) +@ir.register_value_caster(ir.Float8E5M2Type.static_typeid) +@ir.register_value_caster(ir.Float8E4M3Type.static_typeid) +@ir.register_value_caster(ir.Float8E8M0FNUType.static_typeid) +@ir.register_value_caster(ir.BF16Type.static_typeid) +@ir.register_value_caster(ir.F16Type.static_typeid) +@ir.register_value_caster(ir.FloatTF32Type.static_typeid) +@ir.register_value_caster(ir.F32Type.static_typeid) +@ir.register_value_caster(ir.F64Type.static_typeid) 
+@ir.register_value_caster(ir.IntegerType.static_typeid) +@ir.register_value_caster(ir.VectorType.static_typeid) +@ir.register_value_caster(ir.RankedTensorType.static_typeid) +class ArithValue(ir.Value): + """Overloads operators for MLIR's Arith dialects binary operations.""" + + def __init__(self, v, signed: Union[bool, None] = None): + if isinstance(v, int): + v = arith.constant(self.type, v) + super().__init__(v) + + elem_ty = element_type(self.type) + self.is_float = arith._is_float_type(elem_ty) + # arith dialect consider `1` in `i1` as `-1`, treat it as unsigned for DSL + self.signed = signed and elem_ty.width > 1 + + def with_signedness(self, signed: Union[bool, None]): + return type(self)(self, signed) + + def __neg__(self, *, loc=None, ip=None): + if self.type == T.bool(): + raise TypeError( + "Negation, the operator `-` is not supported for boolean type" + ) + + if self.is_float: + return arith.negf(self, loc=loc, ip=ip) + else: + c0 = arith.constant(self.type, 0, loc=loc, ip=ip) + return arith.subi(c0, self, loc=loc, ip=ip) + + @_binary_op + def __pow__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float and other.is_float: + return math.powf(self, other, loc=loc, ip=ip) + elif self.is_float and not other.is_float: + return math.fpowi(self, other, loc=loc, ip=ip) + elif not self.is_float and other.is_float: + lhs = itofp(self, self.signed, T.f32(), loc=loc, ip=ip) + rhs = cvtf(other, T.f32(), loc=loc, ip=ip) + return math.powf(lhs, rhs, loc=loc, ip=ip) + elif not self.is_float and not other.is_float: + return math.ipowi(self, other, loc=loc, ip=ip) + else: + raise DSLNotImplemented(f"Unsupported '{self} ** {other}'") + + @_binary_op + def __rpow__(self, other, *, loc=None, ip=None) -> "ArithValue": + return other.__pow__(self, loc=loc, ip=ip) + + # arith operators + + @_dispatch_to_rhs_r_op + @_binary_op + def __add__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.addf(self, other, loc=loc, 
ip=ip) + else: + return arith.addi(self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __sub__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.subf(self, other, loc=loc, ip=ip) + else: + return arith.subi(self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __mul__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.mulf(self, other, loc=loc, ip=ip) + else: + return arith.muli(self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __truediv__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.divf(self, other, loc=loc, ip=ip) + else: + lhs = itofp(self, self.signed, T.f32(), loc=loc, ip=ip) + rhs = itofp(other, other.signed, T.f32(), loc=loc, ip=ip) + return arith.divf(lhs, rhs, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __floordiv__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + q = arith.divf(self, other, loc=loc, ip=ip) + return math.floor(q, loc=loc, ip=ip) + elif self.signed != False: + return arith.floordivsi(self, other, loc=loc, ip=ip) + else: + return arith.divui(self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __mod__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.remf(self, other, loc=loc, ip=ip) + elif self.signed != False: + return arith.remsi(self, other, loc=loc, ip=ip) + else: + return arith.remui(self, other, loc=loc, ip=ip) + + @_binary_op + def __radd__(self, other, *, loc=None, ip=None) -> "ArithValue": + return other.__add__(self, loc=loc, ip=ip) + + @_binary_op + def __rsub__(self, other, *, loc=None, ip=None) -> "ArithValue": + return other.__sub__(self, loc=loc, ip=ip) + + @_binary_op + def __rmul__(self, other, *, loc=None, ip=None) -> "ArithValue": + return other.__mul__(self, loc=loc, ip=ip) + + @_binary_op + def __rtruediv__(self, other, *, loc=None, 
ip=None) -> "ArithValue": + return other.__truediv__(self, loc=loc, ip=ip) + + @_binary_op + def __rfloordiv__(self, other, *, loc=None, ip=None) -> "ArithValue": + return other.__floordiv__(self, loc=loc, ip=ip) + + @_binary_op + def __rmod__(self, other, *, loc=None, ip=None) -> "ArithValue": + return other.__mod__(self, loc=loc, ip=ip) + + # Comparison operators (comparison doesn't have right-hand-side variants) + @_dispatch_to_rhs_r_op + @_binary_op + def __lt__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.cmpf(arith.CmpFPredicate.OLT, self, other, loc=loc, ip=ip) + elif self.signed != False: + return arith.cmpi(arith.CmpIPredicate.slt, self, other, loc=loc, ip=ip) + else: + return arith.cmpi(arith.CmpIPredicate.ult, self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __le__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.cmpf(arith.CmpFPredicate.OLE, self, other, loc=loc, ip=ip) + elif self.signed != False: + return arith.cmpi(arith.CmpIPredicate.sle, self, other, loc=loc, ip=ip) + else: + return arith.cmpi(arith.CmpIPredicate.ule, self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __eq__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.cmpf(arith.CmpFPredicate.OEQ, self, other, loc=loc, ip=ip) + else: + return arith.cmpi(arith.CmpIPredicate.eq, self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __ne__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + # In Python, bool(float("nan")) is True, so use unordered comparison here + return arith.cmpf(arith.CmpFPredicate.UNE, self, other, loc=loc, ip=ip) + else: + return arith.cmpi(arith.CmpIPredicate.ne, self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __gt__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.cmpf(arith.CmpFPredicate.OGT, 
self, other, loc=loc, ip=ip) + elif self.signed != False: + return arith.cmpi(arith.CmpIPredicate.sgt, self, other, loc=loc, ip=ip) + else: + return arith.cmpi(arith.CmpIPredicate.ugt, self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __ge__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.is_float: + return arith.cmpf(arith.CmpFPredicate.OGE, self, other, loc=loc, ip=ip) + elif self.signed != False: + return arith.cmpi(arith.CmpIPredicate.sge, self, other, loc=loc, ip=ip) + else: + return arith.cmpi(arith.CmpIPredicate.uge, self, other, loc=loc, ip=ip) + + # Unary operators + def __invert__(self, *, loc=None, ip=None) -> "ArithValue": + return arith.xori(self, arith.constant(self.type, -1)) + + # Bitwise operations + @_dispatch_to_rhs_r_op + @_binary_op + def __and__(self, other, *, loc=None, ip=None) -> "ArithValue": + return arith.andi(self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __or__(self, other, *, loc=None, ip=None) -> "ArithValue": + return arith.ori(self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __xor__(self, other, *, loc=None, ip=None) -> "ArithValue": + return arith.xori(self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __rshift__(self, other, *, loc=None, ip=None) -> "ArithValue": + if self.signed != False: + return arith.shrsi(self, other, loc=loc, ip=ip) + else: + return arith.shrui(self, other, loc=loc, ip=ip) + + @_dispatch_to_rhs_r_op + @_binary_op + def __lshift__(self, other, *, loc=None, ip=None) -> "ArithValue": + return arith.shli(self, other, loc=loc, ip=ip) + + @_binary_op + def __rand__(self, other, *, loc=None, ip=None) -> "ArithValue": + return arith.andi(other, self, loc=loc, ip=ip) + + @_binary_op + def __ror__(self, other, *, loc=None, ip=None) -> "ArithValue": + return arith.ori(other, self, loc=loc, ip=ip) + + @_binary_op + def __rxor__(self, other, *, loc=None, ip=None) -> "ArithValue": + return 
arith.xori(other, self, loc=loc, ip=ip) + + @_binary_op + def __rrshift__(self, other, *, loc=None, ip=None) -> "ArithValue": + return other.__rshift__(self, loc=loc, ip=ip) + + @_binary_op + def __rlshift__(self, other, *, loc=None, ip=None) -> "ArithValue": + return other.__lshift__(self, loc=loc, ip=ip) + + def __hash__(self): + return super().__hash__() + + def __str__(self): + return "?" + + def __repr__(self): + return self.__str__() + + +def _min(lhs, rhs, *, loc=None, ip=None): + """ + This function provides a unified interface for building arith min + + Assuming the operands have the same type + """ + from ..dsl import is_dynamic_expression + + if not is_dynamic_expression(lhs): + if not is_dynamic_expression(rhs): + return min(lhs, rhs) + else: + lhs = arith.constant(rhs.type, lhs, loc=loc, ip=ip) + else: + if not is_dynamic_expression(rhs): + rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip) + + if arith._is_integer_like_type(lhs.type): + if lhs.signed != False: + return arith.minsi(lhs, rhs, loc=loc, ip=ip) + else: + return arith.minui(lhs, rhs, loc=loc, ip=ip) + else: + return arith.minimumf(lhs, rhs, loc=loc, ip=ip) + + +def _max(lhs, rhs, *, loc=None, ip=None): + """ + This function provides a unified interface for building arith max + + Assuming the operands have the same type + """ + from ..dsl import is_dynamic_expression + + if not is_dynamic_expression(lhs): + if not is_dynamic_expression(rhs): + return max(lhs, rhs) + else: + lhs = arith.constant(rhs.type, lhs, loc=loc, ip=ip) + else: + if not is_dynamic_expression(rhs): + rhs = arith.constant(lhs.type, rhs, loc=loc, ip=ip) + + if arith._is_integer_like_type(lhs.type): + if lhs.signed != False: + return arith.maxsi(lhs, rhs, loc=loc, ip=ip) + else: + return arith.maxui(lhs, rhs, loc=loc, ip=ip) + else: + return arith.maximumf(lhs, rhs, loc=loc, ip=ip) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/_mlir_helpers/gpu.py 
def printf(fmt, *args, threadNumber=-1):
    """Generate gpu.printf OP predicated on threadNumber.

    ``fmt`` is a ``str.format`` template: each replacement field is filled
    with the printf conversion chosen from the MLIR type of the matching
    argument, and a trailing newline is appended.

    Supported argument types: index (``%llu``), 64-bit integer (``%llu``),
    32-bit integer (``%d``), 1-bit integer (``%i``) and f32 (``%f``).

    :param fmt: ``str.format`` template with one field per argument.
    :param args: IR values to print; each must expose a ``.type``.
    :param threadNumber: ``-1`` emits an unpredicated print; any other value
        emits the print inside an ``scf.if`` comparing
        ``gpu.thread_id(gpu.Dimension.x)`` against it.
    :raises DSLNotImplemented: if an argument has an unsupported type.
    """
    integer_formats = {64: "%llu", 32: "%d", 1: "%i"}

    type_formats = []
    for arg in args:
        ty_format = None
        if ir.IndexType.isinstance(arg.type):
            ty_format = "%llu"
        elif ir.IntegerType.isinstance(arg.type):
            # Other widths (e.g. 8/16) stay None and are rejected below.
            ty_format = integer_formats.get(ir.IntegerType(arg.type).width)
        elif ir.F32Type.isinstance(arg.type):
            ty_format = "%f"
        if ty_format is None:
            raise DSLNotImplemented(arg.type)
        type_formats.append(ty_format)

    message = fmt.format(*type_formats) + "\n"

    if threadNumber == -1:
        gpu.printf(message, args)
        return

    tidx = gpu.thread_id(gpu.Dimension.x)
    # The types helper is imported in this module as `T`; the previous `_T`
    # is never bound here (a star import does not bring in underscore names
    # unless `__all__` lists them), so this branch would raise NameError.
    # NOTE(review): confirm `..common` does not export `_T` via `__all__`.
    predicate = arith.cmpi(
        arith.CmpIPredicate.eq, tidx, arith.constant(T.index(), threadNumber)
    )
    if_op = scf.IfOp(predicate)
    with ir.InsertionPoint(if_op.then_block):
        gpu.printf(message, args)
        scf.yield_([])
def lru_cache_ir(maxsize=128, typed=True):
    """
    Applies an LRU cache to a given function, with awareness of IR context.

    Usage is the same as ``functools.lru_cache``. The value returned by
    ``get_ir_context`` is prepended to the cache key, so results cached under
    one IR context are not reused under another.

    If the context or any argument cannot be hashed (``TypeError`` or
    ``RuntimeError`` while hashing), the cache is bypassed and the function
    is called directly.

    :param maxsize: Max cache size, same as functools.lru_cache
    :param typed: Whether params are type-sensitive, default to True as IR is type-sensitive
    """

    def decorator(func):
        # The context is an extra leading argument so it becomes part of the
        # lru_cache key; it is dropped before calling the real function.
        @lru_cache(maxsize=maxsize, typed=typed)
        def cached_func(context, *args, **kwargs):
            return func(*args, **kwargs)

        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                context = get_ir_context(func)
                # Probe hashability up front. The old code wrapped the cached
                # call itself, so a RuntimeError/TypeError raised *inside*
                # `func` silently triggered a second call to `func`.
                hash((context, args, tuple(kwargs.items())))
            except (RuntimeError, TypeError):
                return func(*args, **kwargs)
            return cached_func(context, *args, **kwargs)

        # Expose cache-related methods for introspection
        wrapper.cache_clear = cached_func.cache_clear
        wrapper.cache_info = cached_func.cache_info
        return wrapper

    return decorator
def dsl_user_op(opFunc):
    """Decorate an op builder so it always receives a ``loc`` keyword.

    If the caller does not pass ``loc`` (or passes ``None``), a location is
    built from the calling frame: a file location (filename, line, column 0)
    wrapped in a name location carrying the caller's function name.

    :param opFunc: Callable accepting a ``loc`` keyword argument.
    :return: Wrapper with the same metadata as ``opFunc``.
    """

    @wraps(opFunc)
    def wrapper(*args, **kwargs):
        loc = kwargs.pop("loc", None)
        if loc is None:
            # currentframe() may return None on interpreters without stack
            # frame support; then `loc` stays None instead of crashing.
            current = inspect.currentframe()
            caller = current.f_back if current is not None else None
            try:
                if caller is not None:
                    file_loc = ir.Location.file(
                        caller.f_code.co_filename, caller.f_lineno, 0
                    )
                    loc = ir.Location.name(caller.f_code.co_name, childLoc=file_loc)
            finally:
                # Drop frame references promptly to avoid reference cycles.
                del current, caller
        return opFunc(*args, **kwargs, loc=loc)

    return wrapper
class Executor:
    """
    The Executor class handles dynamic and compile-time (constexpr) execution
    of "for" loops and "if-else-elif" statements.

    It is a thin dispatcher: the actual IR generation is delegated to the
    callables registered through ``set_functions``.

    Methods:
        set_functions: Assigns the functions for checking loop bounds and
          conditional evaluation.

        for_execute: Generates MLIR for OP
        while_execute: Generates MLIR while OP
        if_execute: Generates MLIR if OP
    """

    def __init__(self):
        # All hooks stay None until `set_functions` is called.
        self._is_dynamic_expression = None
        self._loop_execute_range_dynamic = None
        self._if_dynamic = None
        self._while_dynamic = None
        self._compare_executor = None
        self._any_executor = None
        self._all_executor = None
        self._builtin_redirector = None

    def set_functions(
        self,
        *,
        is_dynamic_expression: Callable,
        loop_execute_range_dynamic: Callable,
        if_dynamic: Callable,
        while_dynamic: Callable,
        compare_executor: Callable,
        any_executor: Optional[Callable] = None,
        all_executor: Optional[Callable] = None,
        builtin_redirector: Optional[Callable] = None,
    ):
        """Register the callables this executor dispatches to (keyword-only)."""
        self._is_dynamic_expression = is_dynamic_expression
        self._loop_execute_range_dynamic = loop_execute_range_dynamic
        self._if_dynamic = if_dynamic
        self._while_dynamic = while_dynamic
        self._compare_executor = compare_executor
        self._any_executor = any_executor
        self._all_executor = all_executor
        self._builtin_redirector = builtin_redirector

    @staticmethod
    def convert_to_list(x):
        """This function is used to convert x to a list.
        If x is None, return an empty list.
        If x is not a list, return a list containing x.
        Otherwise, return x itself.
        """
        if x is None:
            return []
        if not isinstance(x, list):
            return [x]
        return x

    @staticmethod
    def converge_ret_val(res):
        """This function is used to converge res (the return value) of the function.
        If res is None, return None.
        If res is a list and has only one element, return the element.
        Otherwise, return res itself.
        """
        if isinstance(res, list) and len(res) == 1:
            return res[0]
        return res

    def for_execute(
        self,
        func,
        start,
        stop,
        step,
        write_args=None,
        full_write_args_count=0,
        write_args_names=None,
        unroll=-1,
        unroll_full=False,
        prefetch_stages=None,
    ):
        """Forward a ``for`` loop to the registered loop hook.

        ``write_args`` and ``write_args_names`` default to a fresh empty list
        per call, so the hook never sees a list shared between calls.

        :raises AssertionError: if ``set_functions`` has not been called.
        """
        assert (
            self._loop_execute_range_dynamic
        ), "Functions must be set before execution."
        log().debug("start [%s] stop [%s] step [%s]", start, stop, step)

        return self._loop_execute_range_dynamic(
            func,
            start,
            stop,
            step,
            [] if write_args is None else write_args,
            full_write_args_count,
            [] if write_args_names is None else write_args_names,
            unroll,
            unroll_full,
            prefetch_stages,
        )

    def if_execute(
        self,
        pred,
        then_block: Callable,
        else_block: Optional[Callable] = None,
        write_args=None,
        full_write_args_count=0,
        write_args_names=None,
    ):
        """Forward an ``if`` statement to the registered if hook.

        :raises AssertionError: if ``set_functions`` has not been called.
        """
        assert self._if_dynamic, "Functions must be set before execution."

        # MLIR generation
        return self._if_dynamic(
            pred,
            then_block,
            else_block,
            [] if write_args is None else write_args,
            full_write_args_count,
            [] if write_args_names is None else write_args_names,
        )

    def while_execute(
        self,
        pred,
        while_before_block: Callable,
        while_after_block: Callable,
        write_args=None,
        full_write_args_count=0,
        write_args_names=None,
    ):
        """Forward a ``while`` loop to the registered while hook.

        ``pred`` is accepted but not forwarded; the hook receives only the
        before/after blocks and the write-argument description.

        :raises AssertionError: if ``set_functions`` has not been called.
        """
        assert self._while_dynamic, "Functions must be set before execution."

        # MLIR generation
        return self._while_dynamic(
            while_before_block,
            while_after_block,
            [] if write_args is None else write_args,
            full_write_args_count,
            [] if write_args_names is None else write_args_names,
        )
def while_selector(pred, write_args=None):
    """Return a decorator that immediately calls the decorated function.

    The decorated function is invoked as ``func(pred, *write_args)`` and its
    result replaces the function.

    :param pred: Predicate passed as the first positional argument.
    :param write_args: Iterable of extra positional arguments; ``None``
        (the default) means no extra arguments.
    """
    extra_args = () if write_args is None else write_args

    def ir_while_loop(func):
        return func(pred, *extra_args)

    return ir_while_loop


def while_executor(
    pred,
    while_before_block: Callable,
    while_after_block: Callable,
    write_args=None,
    full_write_args_count=0,
    write_args_names=None,
):
    """Module-level entry point that forwards to ``executor.while_execute``.

    List arguments default to a fresh empty list per call rather than a
    shared mutable default.
    """
    return executor.while_execute(
        pred,
        while_before_block,
        while_after_block,
        [] if write_args is None else write_args,
        full_write_args_count,
        [] if write_args_names is None else write_args_names,
    )


def if_executor(
    pred,
    then_block: Callable,
    else_block: Optional[Callable] = None,
    write_args=None,
    full_write_args_count=0,
    write_args_names=None,
):
    """Module-level entry point that forwards to ``executor.if_execute``.

    List arguments default to a fresh empty list per call rather than a
    shared mutable default.
    """
    return executor.if_execute(
        pred,
        then_block,
        else_block,
        [] if write_args is None else write_args,
        full_write_args_count,
        [] if write_args_names is None else write_args_names,
    )
+ + The class supports both single-argument (stop) and three-argument + (start, stop, step) constructors with additional parameters for loop + optimization: + + - unroll: Number of iterations to unroll (0 or 1 = no unrolling) + - unroll_full: Whether to fully unroll the loop + - prefetch_stages: Number of prefetch stages to generate + """ + + @overload + def __new__(cls, stop, unroll=0, unroll_full=False, prefetch_stages=None): + pass + + @overload + def __new__( + cls, start, stop, step, unroll=0, unroll_full=False, prefetch_stages=None + ): + pass + + def __new__(cls, *args, **kwargs): + raise DSLRuntimeError("dynamic range should be always preprocessed to IR") + + def __iter__(self) -> Iterator[int]: + raise DSLRuntimeError("dynamic range should be always preprocessed to IR") + + +@deprecated( + "range_dynamic is deprecated and will be removed in the future, please remove it." +) +def range_dynamic(*args, **kwargs): + raise DSLRuntimeError("range_dynamic should be always preprocessed to IR") + + +def range_constexpr(*args): + raise DSLRuntimeError("range_constexpr should be preprocessed by preprocessor.") + + +# ============================================================================= +# If expressions +# ============================================================================= + + +def const_expr(expression): + """ + This function is used to check if the expression is a python value. + If the expression is a python value, return the boolean value of the expression. + If the expression is a dynamic expression, raise an error. 
def assert_executor(test, msg=None):
    """Evaluate a Python ``assert`` whose condition must be a compile-time value.

    A dynamic expression is only accepted when it is a ``Numeric`` that can
    be converted with ``test.to(bool)``; the converted value is then asserted.

    :param test: Condition to assert.
    :param msg: Optional assertion message.
    :raises AssertionError: if the (possibly converted) condition is falsy.
    :raises DSLRuntimeError: if ``test`` is a dynamic expression that is not
        a ``Numeric`` or cannot be converted to ``bool``.
    """
    from .typing import Numeric

    # Implicit convert dynamic expression to bool is not allowed
    # So here explicitly do a None check
    if test is not None and executor._is_dynamic_expression(test):
        convertible = isinstance(test, Numeric)
        if convertible:
            try:
                test = test.to(bool)
            except Exception:
                # Was a bare `except:`, which also swallowed
                # KeyboardInterrupt / SystemExit.
                convertible = False
        if not convertible:
            raise DSLRuntimeError(
                "Only constexpr (Python Value) is allowed here, but got non-constexpr (IR Values) expression.",
                suggestion="Please replace with runtime assert.",
            )

    assert test, msg
class DSLOptimizationWarning(Warning):
    """
    This warning is used to warn the user about the optimization related issues in DSL.
    """

    def __init__(self, message):
        self.message = message
        # Pass the message on so `args` (and therefore repr / pickling)
        # carry it; previously `args` was always empty.
        super().__init__(message)

    def __str__(self):
        return self.message


def range_value_check(*args):
    """
    Ensure all `range_constexpr` bounds are compile-time constants (Python ints).

    Each argument is converted with ``__index__``. One, two or three
    arguments are interpreted like the builtin ``range``; any other count
    leaves the defaults ``(0, 0, 1)``.

    A ``DSLOptimizationWarning`` is emitted when the estimated iteration
    count is 64 or more.

    :return: ``(start, end, step)`` as Python ints.
    :raises DSLRuntimeError: if an argument cannot be converted with
        ``__index__`` or the iteration count cannot be computed (e.g. a
        zero step).
    """
    try:
        bounds = tuple(arg.__index__() for arg in args)

        start, end, step = 0, 0, 1
        if len(bounds) == 1:
            (end,) = bounds
        elif len(bounds) == 2:
            start, end = bounds
        elif len(bounds) == 3:
            start, end, step = bounds

        # Estimate based on the distance between the bounds.
        range_length = (abs(end - start) - 1) // abs(step) + 1
    except Exception as err:
        raise DSLRuntimeError(
            "`range_constexpr` requires constexpr (compile-time constant) for all arguments.",
            suggestion="Use `range` instead of `range_constexpr`.",
        ) from err

    # Warn outside the try block: with warnings configured as errors, the
    # old bare `except:` caught the warning itself and misreported it as a
    # non-constexpr argument.
    if range_length >= 64:
        warnings.warn(
            f"This static loop has {range_length} iterations, which may be very slow to compile, consider using `cutlass.range(..., unroll_full=True)` instead.",
            category=DSLOptimizationWarning,
            stacklevel=2,
        )

    return (start, end, step)
def copy_members(dest, src):
    """
    Overwrite the data members of ``dest`` with the same-named members of ``src``.

    The members of ``dest`` (as reported by ``inspect.getmembers``) are
    walked; a member is skipped when its name starts with a double
    underscore, when its current value on ``dest`` is callable, or when
    ``src`` has no attribute of that name. Nothing is done if ``dest`` and
    ``src`` are the same object.
    """
    if dest is src:
        return

    for name, value in getmembers(dest):
        if name.startswith("__") or callable(value) or not hasattr(src, name):
            continue
        setattr(dest, name, getattr(src, name))


def get_locals_or_none(locals, symbols):
    """
    Given a locals() dictionary and a list of symbol names, return a list of their values
    in the same order as the symbols list. If a symbol is not present in locals, None is returned
    for that symbol.
    """
    # `locals` shadows the builtin, but the name is part of the call
    # signature and is kept for compatibility.
    return [locals[symbol] if symbol in locals else None for symbol in symbols]
class OrderedSet:
    """
    A deterministic set implementation for ordered operations.

    Elements are stored as the keys of a ``dict``, so iteration follows
    insertion order. Binary operators accept any operand supporting ``in``
    (for ``&`` and ``-``) or iteration (for ``|``).
    """

    def __init__(self, iterable=None):
        self._dict = dict.fromkeys(iterable or [])

    def add(self, item):
        """Insert ``item``; an existing item keeps its original position."""
        self._dict[item] = None

    def __iter__(self):
        return iter(self._dict)

    def __contains__(self, item):
        # Without this, `x in ordered_set` fell back to iterating __iter__,
        # making `&`, `-` and `intersections` quadratic whenever the other
        # operand was also an OrderedSet.
        return item in self._dict

    def __and__(self, other):
        return OrderedSet(key for key in self._dict if key in other)

    def __or__(self, other):
        merged = self._dict.copy()
        merged.update(dict.fromkeys(other))
        return OrderedSet(merged)

    def __sub__(self, other):
        return OrderedSet(key for key in self._dict if key not in other)

    def intersections(self, others):
        """Compute the intersection of this set with multiple other sets.

        :param others: A list of sets to compute intersections with
        :type others: List[Set[str]]
        :return: A new ordered set containing elements that appear in this set
            and at least one of the other sets
        """
        return OrderedSet(
            key
            for key in self._dict
            if any(key in other for other in reversed(others))
        )
@dataclass
class ScopeManager:
    """
    Tracks the symbols introduced in each nested scope while the AST is walked.

    Used as a context manager: entering pushes a new empty scope, leaving
    pops the innermost one.
    """

    # Stack of scopes, outermost first; each scope is a set of symbol names.
    scopes: List[Set[str]]

    @classmethod
    def create(cls) -> "ScopeManager":
        """Build a manager with an empty scope stack."""
        return cls(scopes=[])

    def add_to_scope(self, name: str) -> None:
        """Record ``name`` in the innermost scope; the name ``_`` is ignored."""
        if name != "_":
            innermost = self.scopes[-1]
            innermost.add(name)

    def get_active_symbols(self) -> List[Set[str]]:
        """Return a shallow copy of the scope stack."""
        return list(self.scopes)

    def __enter__(self) -> "ScopeManager":
        fresh_scope: Set[str] = set()
        self.scopes.append(fresh_scope)
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        del self.scopes[-1]
_create_module_attribute( + self, + func_name, + *, + top_module_name="_dsl_", + submodule_name="ast_helpers", + lineno=None, + col_offset=None, + ): + # If we simply copy location from origin node, it contains a way to wide range, which cause location in traceback to be wrong. + def set_location(node, lineno, col_offset): + if lineno and col_offset: + node.lineno = lineno + node.end_lineno = lineno + node.col_offset = col_offset + node.end_col_offset = col_offset + + base = ast.Name(id=top_module_name, ctx=ast.Load()) + set_location(base, lineno, col_offset) + if submodule_name: + base = ast.Attribute(value=base, attr=submodule_name, ctx=ast.Load()) + set_location(base, lineno, col_offset) + node = ast.Attribute(value=base, attr=func_name, ctx=ast.Load()) + set_location(node, lineno, col_offset) + return node + + def _get_module_imports(self, decorated_func): + """Extract imports from the module containing the decorated function""" + imports = [] + + # Get the module containing the decorated function + if module := inspect.getmodule(decorated_func): + try: + # Get the module source code + source = inspect.getsource(module) + module_ast = ast.parse(source) + + # Extract imports from the full module + alias = lambda n: n.asname if n.asname else n.name + for node in ast.walk(module_ast): + if isinstance(node, ast.Import): + for name in node.names: + imports.append( + ImportInfo( + module_path=name.name, + attr_name=None, + alias_name=alias(name), + ) + ) + elif isinstance(node, ast.ImportFrom): + module_name = node.module + if node.level > 0: + # Handle relative imports + package_name = module.__package__.rsplit( + ".", node.level - 1 + )[0] + module_name = f"{package_name}.{module_name}" + for name in node.names: + imports.append( + ImportInfo( + module_path=module_name, + attr_name=name.name, + alias_name=alias(name), + ) + ) + except (IOError, TypeError): + pass + + return imports + + def exec(self, function_name, original_function, code_object, exec_globals): + # 
def make_func_param_name(self, base_name, used_names):
    """Return ``base_name`` or the first free ``base_name_<i>`` (i = 0, 1, ...) not in ``used_names``."""
    candidate = base_name
    suffix = 0
    while candidate in used_names:
        candidate = f"{base_name}_{suffix}"
        suffix += 1
    return candidate
def check_early_exit(self, tree, kind):
    """
    Raise if the region rooted at ``tree`` contains an early exit.

    ``return`` and ``raise`` always count. ``break`` / ``continue`` count only
    when they are not nested in an inner ``for`` / ``while`` of the region and
    ``kind`` is not ``"if"``. Returns ``None`` when no early exit is found.
    """

    class _ExitFinder(ast.NodeVisitor):
        """Remembers the last early-exit statement met during traversal."""

        def __init__(self, region_kind):
            self.region_kind = region_kind
            self.exit_node = None
            self.exit_label = None
            self.found = False
            self.inner_loops = 0

        def _record(self, stmt, label):
            self.found = True
            self.exit_node = stmt
            self.exit_label = label

        def _record_loop_exit(self, stmt, label):
            # break/continue inside an inner loop only leaves that inner loop.
            if self.region_kind != "if" and self.inner_loops == 0:
                self._record(stmt, label)

        def _descend_into_loop(self, loop):
            self.inner_loops += 1
            self.generic_visit(loop)
            self.inner_loops -= 1

        # Early exit is not allowed in any level of dynamic control flow
        def visit_Return(self, node):
            self._record(node, "return")

        def visit_Raise(self, node):
            self._record(node, "raise")

        def visit_Break(self, node):
            self._record_loop_exit(node, "break")

        def visit_Continue(self, node):
            self._record_loop_exit(node, "continue")

        def visit_For(self, node):
            self._descend_into_loop(node)

        def visit_While(self, node):
            self._descend_into_loop(node)

    finder = _ExitFinder(kind)
    # Visit only the children: the region's own loop must not count as nesting.
    finder.generic_visit(tree)
    if finder.found:
        class_suffix = f" in `{self.class_name}`" if self.class_name else ""
        raise DSLAstPreprocessorError(
            message=f"Early exit ({finder.exit_label}) is not allowed in `{self.function_name}`"
            + class_suffix,
            filename=self.file_name,
            snippet=ast.unparse(tree),
            suggestion=(
                "If predicates are constant expression, write like "
                "`if const_expr(...)` or `for ... in range_constexpr(...)`. "
                "In that case, early exit will be executed by Python "
                "interpreter, so it's supported."
            ),
        )
def analyze_region_variables(
    self, node: Union[ast.For, ast.If], active_symbols: List[Set[str]]
):
    """
    Find the variables a control-flow region writes or invokes methods on.

    Args:
        node: The region (statement) to analyze.
        active_symbols: Symbol sets visible before the region; only names
            present in them are reported.

    Returns:
        ``(names, write_count)`` where ``names`` lists the written variables
        followed by the variables that are only the base of a method call,
        and ``write_count`` is the number of written variables.

    Raises:
        DSLAstPreprocessorError: If the region calls a local closure by name.
    """

    # we need orderedset to keep the insertion order the same. otherwise generated IR is different each time
    write_args = OrderedSet()
    invoked_args = OrderedSet()
    local_closure = self.local_closures
    file_name = self.file_name
    region_node = node

    class RegionAnalyzer(ast.NodeVisitor):
        # While True, every visited Name is recorded as written, whatever its ctx.
        force_store = False

        def visit_Name(self, node):
            """
            Mark every store as write.
            """
            if isinstance(node.ctx, ast.Store) or self.force_store:
                write_args.add(node.id)

        def visit_Subscript(self, node):
            # When subscript occurs on the lhs of an assignment, the `Name` is still a load, but `Subscript` is marked as `Store`.
            # We need to force the store for the `Name` to be marked as write.
            if isinstance(node.ctx, ast.Store):
                outer = self.force_store
                self.force_store = True
                self.visit(node.value)
                # The index expression is only read.
                self.force_store = False
                self.visit(node.slice)
                # Restore rather than reset, so later elements of an enclosing
                # unpacking target (``a[i], b.c = ...``) are still seen as writes.
                self.force_store = outer
            else:
                self.generic_visit(node)

        def visit_Assign(self, node):
            self.force_store = True
            for target in node.targets:
                self.visit(target)
            self.force_store = False
            self.visit(node.value)

        def visit_AugAssign(self, node):
            self.force_store = True
            self.visit(node.target)
            self.force_store = False
            self.visit(node.value)

        @staticmethod
        def get_call_base(func_node):
            """Return the root name of an attribute chain (``a`` for ``a.b.c``), else None."""
            if not isinstance(func_node, ast.Attribute):
                return None
            base = func_node.value
            while isinstance(base, ast.Attribute):
                base = base.value
            # Could be something else (lambda, call, etc.)
            return base.id if isinstance(base, ast.Name) else None

        def visit_Call(self, node):
            base_name = RegionAnalyzer.get_call_base(node.func)

            if isinstance(node.func, ast.Name):
                func_name = node.func.id
                if func_name in local_closure:
                    raise DSLAstPreprocessorError(
                        f"Function `{func_name}` is a closure and is not supported in for/if statements",
                        filename=file_name,
                        snippet=ast.unparse(region_node),
                    )

            # Classes are mutable by default. Mark them as write. If they are
            # dataclass(frozen=True), treat them as read in runtime.
            # Compare by equality: `not in ("self")` is a substring test on the
            # string "self" and would also skip bases named "s", "e", "el", ...
            if base_name is not None and base_name != "self":
                invoked_args.add(base_name)

            self.generic_visit(node)

    analyzer = RegionAnalyzer()
    analyzer.visit(ast.Module(body=node))

    # If arg is both write and invoke, remove from invoked_args
    invoked_args = invoked_args - write_args

    write_args = list(write_args.intersections(active_symbols))
    invoked_args = list(invoked_args.intersections(active_symbols))

    return write_args + invoked_args, len(write_args)
def issue_deprecation_warning(self, *, message, category, filename, lineno):
    """Emit ``message`` as a ``category`` warning attributed to ``filename:lineno``.

    The warning is shown even if the active filters would suppress it. The
    override is limited to this call: ``warnings.catch_warnings`` restores the
    process-wide filter list on exit, so user configuration is left untouched.

    Args:
        message: Warning text.
        category: Warning class to emit.
        filename: File the warning is attributed to.
        lineno: Line number the warning is attributed to.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("always", category)
        warnings.warn_explicit(
            message, category=category, filename=filename, lineno=lineno
        )
def visit_BoolOp(self, node):
    """
    Rewrite ``and`` / ``or`` into calls to ``cutlass.and_`` / ``cutlass.or_``.

    Each operand pair becomes a conditional expression of the form
    ``lhs if type(lhs) == bool and lhs == <dominating value> else helper(lhs, rhs)``
    so a plain Python bool still short-circuits in the interpreter. Chains
    with more than two operands are folded left to right.

    Raises:
        DSLAstPreprocessorError: If the operator is neither ``and`` nor ``or``.
    """
    # Children first, so nested boolean operators are already rewritten.
    self.generic_visit(node)

    # The dominating value is the operand value that decides the result alone:
    # False for "and", True for "or".
    if isinstance(node.op, ast.And):
        helper_name, dominating = "and_", False
    elif isinstance(node.op, ast.Or):
        helper_name, dominating = "or_", True
    else:
        # BoolOp should be either And or Or
        raise DSLAstPreprocessorError(
            f"Unsupported boolean operation: {node.op}",
            filename=self.file_name,
            snippet=ast.unparse(node),
        )

    short_circuit_value = ast.Constant(value=dominating)
    helper_func = self._create_module_attribute(
        helper_name,
        top_module_name="cutlass",
        submodule_name=None,
        lineno=node.lineno,
        col_offset=node.col_offset,
    )
    self.import_top_module = True

    def is_dominating_bool(value):
        # type(value) == bool and value == <dominating value>
        type_check = ast.Compare(
            left=ast.Call(
                func=ast.Name(id="type", ctx=ast.Load()),
                args=[value],
                keywords=[],
            ),
            ops=[ast.Eq()],
            comparators=[ast.Name(id="bool", ctx=ast.Load())],
        )
        value_check = ast.Compare(
            left=value,
            ops=[ast.Eq()],
            comparators=[short_circuit_value],
        )
        return ast.BoolOp(op=ast.And(), values=[type_check, value_check])

    operands = node.values
    folded = operands[0]
    for rhs in operands[1:]:
        condition = is_dominating_bool(folded)
        combined = ast.Call(func=helper_func, args=[folded, rhs], keywords=[])
        folded = ast.IfExp(test=condition, body=folded, orelse=combined)

    return ast.copy_location(folded, node)
def visit_For(self, node):
    """
    Rewrite a ``for`` statement.

    Loops over ``range_constexpr(...)`` or over a non-range iterable stay
    Python loops; ``range_constexpr`` is renamed to ``range`` and preceded by
    a symbol check. Every other supported range loop is handed to
    ``transform_for_loop``.

    Returns:
        The node itself, or a list of statements replacing it.
    """
    # For static for loop (for with range_constexpr or not range based for), preprocessor keeps the loop.
    range_kind, is_builtin_range, has_keyword = self._get_range_kind(node.iter)
    if range_kind == "range_constexpr" or range_kind is None:
        self.generic_visit(node)
        if range_kind == "range_constexpr":
            check_call = self._insert_cf_symbol_check(node.iter.func)
            # Rewrite range_constexpr to range
            node.iter.func = ast.Name(id="range", ctx=ast.Load())
            self._insert_range_value_check(node)
            return [check_call, node]
        return node

    # Snapshot the symbols visible before the loop opens its own scope.
    active_symbols = self.scope_manager.get_active_symbols()

    with self.scope_manager:
        if isinstance(node.target, ast.Name):
            self.scope_manager.add_to_scope(node.target.id)

        if range_kind == "range_dynamic":
            # Generate a warning
            self.issue_deprecation_warning(
                message="range_dynamic is deprecated and will be removed in the future, please remove it.",
                category=DeprecationWarning,
                filename=self.file_name,
                lineno=node.iter.lineno,
            )

        warning_call = None
        if range_kind == "range" and is_builtin_range and not has_keyword:
            # Warn about possible performance regression due to behavior change
            warning_call = ast.Expr(
                ast.Call(
                    func=self._create_module_attribute(
                        "range_perf_warning",
                        lineno=node.lineno,
                        col_offset=node.col_offset,
                    ),
                    args=[
                        ast.Constant(value=self.file_name),
                        ast.Constant(value=node.iter.lineno),
                    ]
                    + node.iter.args,
                    keywords=[],
                )
            )
            ast.copy_location(warning_call, node.iter)

        is_prefixed_range = range_kind == "range" and not is_builtin_range
        check_call = None
        if range_kind == "range_dynamic" or is_prefixed_range:
            # Insert a check for range symbol
            if not is_prefixed_range:
                check_call = self._insert_cf_symbol_check(node.iter.func)
            else:
                # Get toplevel module
                check_call = self._insert_cf_symbol_check(node.iter.func.value)

        new_for_node = self.transform_for_loop(node, active_symbols)
        if check_call is not None:
            new_for_node = [check_call] + new_for_node

        return new_for_node if warning_call is None else [warning_call] + new_for_node
@staticmethod + def _hoist_expr_to_assignments(expr, name): + return ast.copy_location( + ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=expr), expr + ) + + def _build_select_and_assign(self, *, name, test, body, orelse, location): + node = ast.copy_location( + ast.Assign( + targets=[ast.Name(id=name, ctx=ast.Store())], + value=ast.IfExp( + test=test, + body=body, + orelse=orelse, + ), + ), + location, + ) + self.generic_visit(node) + return node + + def _handle_negative_step(self, node, start_expr, stop_expr, step_expr): + # hoist start, stop, step to assignments + start_ori_name = f"start_ori_{self.counter}" + start = self._hoist_expr_to_assignments(start_expr, start_ori_name) + stop_ori_name = f"stop_ori_{self.counter}" + stop = self._hoist_expr_to_assignments(stop_expr, stop_ori_name) + step_ori_name = f"step_ori_{self.counter}" + step = self._hoist_expr_to_assignments(step_expr, step_ori_name) + + extra_exprs = [start, stop, step] + + # Handle possible negative step, generates the following code in Python: + # isNegative = step < 0 + isNegative_name = f"isNegative_{self.counter}" + isNegative = ast.copy_location( + ast.Assign( + targets=[ast.Name(id=isNegative_name, ctx=ast.Store())], + value=ast.Compare( + left=ast.Name(id=step_ori_name, ctx=ast.Load()), + ops=[ast.Lt()], + comparators=[ast.Constant(value=0)], + ), + ), + step, + ) + + # start = stop if isNegative else start + start_name = f"start_{self.counter}" + start = self._build_select_and_assign( + name=start_name, + test=ast.Name(id=isNegative_name, ctx=ast.Load()), + body=ast.Name(id=stop_ori_name, ctx=ast.Load()), + orelse=ast.Name(id=start_ori_name, ctx=ast.Load()), + location=start, + ) + + # stop = start if isNegative else stop + stop_name = f"stop_{self.counter}" + stop = self._build_select_and_assign( + name=stop_name, + test=ast.Name(id=isNegative_name, ctx=ast.Load()), + body=ast.Name(id=start_ori_name, ctx=ast.Load()), + orelse=ast.Name(id=stop_ori_name, ctx=ast.Load()), + 
def transform_for_loop(self, node, active_symbols):
    """
    Replace a dynamic ``for`` loop by a decorated loop-body function.

    Args:
        node: The ``ast.For`` statement. Its ``body`` is mutated in place.
        active_symbols: Symbol sets visible before the loop. Mutated in place:
            the set containing the loop target (if any) is removed and a
            one-element set with the loop-carried variable is appended.

    Returns:
        A list of statements: optional setup assignments, the loop-body
        function definition, the control-flow call/assignment built by
        ``create_cf_call``, and an optional write-back of the loop target.

    Raises:
        DSLAstPreprocessorError: If the loop has an ``else`` clause, or (via
            ``check_early_exit``) contains an early exit.
    """
    # Check for early exit and raise exception
    self.check_early_exit(node, "for")
    if node.orelse:
        raise DSLAstPreprocessorError(
            "dynamic for loop with else is not supported",
            filename=self.file_name,
            snippet=ast.unparse(node),
        )

    # Get loop target variable name
    target_var_name = None
    target_var_is_active_before_loop = False
    if isinstance(node.target, ast.Name):
        target_var_name = node.target.id
        for active_symbol in active_symbols:
            if target_var_name in active_symbol:
                target_var_is_active_before_loop = True
                # Safe while iterating only because the loop breaks right after.
                active_symbols.remove(active_symbol)
                break

    # Add necessary exprs to handle this
    if target_var_is_active_before_loop:
        # Initialize an extra loop carried variable
        loop_carried_var_name = f"loop_carried_var_{self.counter}"
        pre_loop_expr = ast.copy_location(
            ast.Assign(
                targets=[ast.Name(id=loop_carried_var_name, ctx=ast.Store())],
                value=ast.Name(id=target_var_name, ctx=ast.Load()),
            ),
            node,
        )
        # append an extra assignment to the loop carried variable
        # (done before analyze_region_variables so the analysis sees this write)
        node.body.append(
            ast.copy_location(
                ast.Assign(
                    targets=[ast.Name(id=loop_carried_var_name, ctx=ast.Store())],
                    value=ast.Name(id=target_var_name, ctx=ast.Load()),
                ),
                node,
            )
        )
        active_symbols.append({loop_carried_var_name})

    start_expr, stop_expr, step_expr, has_step = self.extract_range_args(node.iter)
    unroll, unroll_full = self.extract_unroll_args(node.iter)
    prefetch_stages = self.extract_prefetch_stages_args(node.iter)
    write_args, full_write_args_count = self.analyze_region_variables(
        node, active_symbols
    )

    # Negative-step handling applies only to an explicit three-argument range
    # and only when the first client module name component is "cutlass".
    if has_step and self.client_module_name[0] == "cutlass":
        start, stop, step, exprs = self._handle_negative_step(
            node, start_expr, stop_expr, step_expr
        )
    else:
        start, stop, step, exprs = start_expr, stop_expr, step_expr, []

    if target_var_is_active_before_loop:
        exprs.append(pre_loop_expr)

    # The same counter value names the helper variables above and this function;
    # it is incremented only here.
    func_name = f"loop_body_{self.counter}"
    self.counter += 1

    func_def = self.create_loop_function(
        func_name,
        node,
        start,
        stop,
        step,
        unroll,
        unroll_full,
        prefetch_stages,
        write_args,
        full_write_args_count,
    )

    assign = self.create_cf_call(func_name, write_args, node)

    # This should work fine as it modifies the AST structure
    exprs = exprs + [func_def] + assign

    if target_var_is_active_before_loop:
        # Create a new assignment to the target variable
        exprs.append(
            ast.copy_location(
                ast.Assign(
                    targets=[ast.Name(id=target_var_name, ctx=ast.Store())],
                    value=ast.Name(id=loop_carried_var_name, ctx=ast.Load()),
                ),
                node,
            )
        )

    return exprs
+ ) + elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name): + def create_downcast_call(arg): + return ast.copy_location( + ast.Call( + func=self._create_module_attribute( + self.IMPLICIT_DOWNCAST_NUMERIC_TYPE, + submodule_name="typing", + lineno=node.lineno, + col_offset=node.col_offset, + ), + args=[arg], + keywords=[], + ), + arg, + ) + module = self.function_globals.get(func.value.id) + if isinstance(module, ModuleType) and module.__package__.endswith( + "._mlir.dialects" + ): + # Check if argument is Numeric, if so, call ir_value() + args = [] + for arg in node.args: + args.append(create_downcast_call(arg)) + kwargs = [] + for kwarg in node.keywords: + kwargs.append( + ast.copy_location( + ast.keyword( + arg=kwarg.arg, + value=create_downcast_call(kwarg.value), + ), + kwarg, + ) + ) + return ast.copy_location( + ast.Call(func=func, args=args, keywords=kwargs), node + ) + else: + node.func = self.visit(node.func) + + return node + + def visit_ClassDef(self, node): + self.class_name = node.name + self.generic_visit(node) + self.class_name = None + return node + + def _visit_target(self, target): + if isinstance(target, ast.Name): + self.scope_manager.add_to_scope(target.id) + elif isinstance(target, ast.Tuple): + for t in target.elts: + if isinstance(t, ast.Name): + self.scope_manager.add_to_scope(t.id) + + def visit_Assign(self, node): + for target in node.targets: + self._visit_target(target) + self.generic_visit(node) + return node + + def visit_AugAssign(self, node): + self._visit_target(node.target) + self.generic_visit(node) + return node + + def visit_Name(self, node): + isLoad = isinstance(node.ctx, ast.Load) + if node.id in ["max", "min", "any", "all"] and isLoad: + return ast.copy_location( + ast.Call( + func=self._create_module_attribute( + "redirect_builtin_function", + lineno=node.lineno, + col_offset=node.col_offset, + ), + args=[node], + keywords=[], + ), + node, + ) + elif node.id == "_" and isLoad: + raise 
def check_decorator(self, node: ast.AST) -> bool:
    """
    Check if the function has the correct decorator for preprocessing.

    A function qualifies when one of its decorators is an attribute access
    ending in ``jit`` or ``kernel``, either bare (``@x.jit``) or called
    (``@x.jit(...)``). For the called form an explicit ``preprocess=<literal>``
    keyword decides the result; when the keyword is absent the decorator counts
    as enabled, exactly like the bare and empty-call forms. A ``preprocess``
    value that is not a literal, or arguments passed through ``**kwargs``,
    cannot be decided statically and that decorator is skipped.

    Args:
        node: First statement of the parsed function source.

    Returns:
        True if the function should be preprocessed, False otherwise.
    """
    if not isinstance(node, ast.FunctionDef):
        return False

    dsl_decorators = ("jit", "kernel")
    for decorator in node.decorator_list:
        if isinstance(decorator, ast.Attribute):
            if decorator.attr in dsl_decorators:
                return True
            continue

        is_dsl_call = (
            isinstance(decorator, ast.Call)
            and isinstance(decorator.func, ast.Attribute)
            and decorator.func.attr in dsl_decorators
        )
        if not is_dsl_call:
            continue

        explicit = next(
            (kw for kw in decorator.keywords if kw.arg == "preprocess"), None
        )
        if explicit is None:
            # ``**opts`` (keyword with arg None) may hide a preprocess flag.
            if any(kw.arg is None for kw in decorator.keywords):
                continue
            return True

        try:
            if isinstance(explicit.value, ast.Constant):
                return bool(explicit.value.value)
            return bool(ast.literal_eval(explicit.value))
        except (ValueError, TypeError, SyntaxError):
            # Not a literal: the flag is only known at run time.
            continue

    return False
def visit_FunctionDef(self, node):
    """
    Visit a function definition inside a fresh scope.

    Registers the function name and its positional-only, regular and
    keyword-only parameters in the scope, records nested functions as local
    closures, visits the body, and strips the ``.jit`` / ``.kernel``
    decorators. ``*args`` / ``**kwargs`` names are not registered.

    Returns:
        The same node, updated in place.
    """
    with self.scope_manager:
        self.function_counter += 1
        self.function_name = node.name
        # A function defined while another one is being visited is a closure.
        if self.function_depth > 0:
            self.local_closures.add(node.name)

        self.function_depth += 1

        # Add function name and arguments
        self.scope_manager.add_to_scope(node.name)
        signature = node.args
        # Positional-only and keyword-only parameters are ordinary local names
        # too, so they are registered alongside the regular ones.
        for arg in [*signature.posonlyargs, *signature.args, *signature.kwonlyargs]:
            self.scope_manager.add_to_scope(arg.arg)

        self.generic_visit(node)

        self.function_depth -= 1

        # Remove .jit and .kernel decorators
        node.decorator_list = self.remove_dsl_decorator(node.decorator_list)
    return node
def create_cf_call(self, func_name, yield_args, node):
    """Build the statements that bind a control-flow region's results.

    Args:
        func_name: Name of the generated region function.
        yield_args: Names of the variables the region yields. Not modified.
        node: Original statement, used as the source location of the result.

    Returns:
        A list of statements:
        * no yielded names: the expression statement ``func_name``;
        * otherwise: ``<targets> = func_name``, where a yielded ``self`` is
          received as ``yield_self`` and followed by
          ``copy_members(self, yield_self)``.
    """
    if not yield_args:
        return [
            ast.copy_location(
                ast.Expr(value=ast.Name(id=func_name, ctx=ast.Load())), node
            )
        ]

    # Work on a copy: renaming "self" must not leak into the caller's list.
    targets = list(yield_args)
    has_self = "self" in targets
    if has_self:
        targets[targets.index("self")] = "yield_self"

    if len(targets) == 1:
        lhs = ast.Name(id=targets[0], ctx=ast.Store())
    else:
        lhs = ast.Tuple(
            elts=[ast.Name(id=var, ctx=ast.Store()) for var in targets],
            ctx=ast.Store(),
        )
    assign = ast.Assign(
        targets=[lhs],
        value=ast.Name(id=func_name, ctx=ast.Load()),
    )

    if not has_self:
        return [ast.copy_location(assign, node)]

    # `self` cannot be rebound by the assignment, so copy the members back.
    fix_self = ast.Expr(
        value=ast.Call(
            func=self._create_module_attribute(
                "copy_members", lineno=node.lineno, col_offset=node.col_offset
            ),
            args=[
                ast.Name(id="self", ctx=ast.Load()),
                ast.Name(id="yield_self", ctx=ast.Load()),
            ],
            keywords=[],
        )
    )
    return [ast.copy_location(assign, node), ast.copy_location(fix_self, node)]
def compare_ops_to_str(self, node):
    """Return an ``ast.List`` of string constants, one per operator of the comparison ``node``."""
    symbols = []
    for operator in node.ops:
        # cmpops is keyed by the operator node's class name, e.g. "LtE".
        kind = type(operator).__name__
        symbols.append(ast.Constant(value=self.cmpops[kind]))
    return ast.List(elts=symbols, ctx=ast.Load())
def generate_get_locals_or_none_call(self, write_args):
    """Build the AST for ``get_locals_or_none(locals(), [<names>])``.

    ``write_args`` supplies the variable names; each becomes a string constant
    in the list passed as the second argument. The callee is resolved through
    ``self._create_module_attribute("get_locals_or_none")``.
    """
    helper = self._create_module_attribute("get_locals_or_none")
    locals_call = ast.Call(
        func=ast.Name(id="locals", ctx=ast.Load()), args=[], keywords=[]
    )
    name_constants = ast.List(
        elts=[ast.Constant(value=name) for name in write_args],
        ctx=ast.Load(),
    )
    return ast.Call(func=helper, args=[locals_call, name_constants], keywords=[])
ast.FunctionDef( + name=then_block_name, + args=func_then_else_arguments, + body=then_body + [ast.Return(value=return_list)], + decorator_list=[], + ), + node, + ) + + # Decorator keywords + decorator_keywords = [ + ast.keyword( + arg="pred", value=test_expr + ), # ast.Name(id="pred", ctx=ast.Load()) + ast.keyword( + arg="write_args", + value=self.generate_get_locals_or_none_call(write_args), + ), + ] + + # Create decorator + decorator = ast.copy_location( + ast.Call( + func=self._create_module_attribute( + self.DECORATOR_IF_STATEMENT, + lineno=node.lineno, + col_offset=node.col_offset, + ), + args=[], + keywords=decorator_keywords, + ), + node, + ) + + # Executor keywords + execute_keywords = [ + ast.keyword(arg="pred", value=ast.Name(id=pred_name, ctx=ast.Load())), + ast.keyword( + arg="write_args", + value=ast.List( + elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args], + ctx=ast.Load(), + ), + ), + ast.keyword( + arg="full_write_args_count", + value=ast.Constant(value=full_write_args_count), + ), + ast.keyword( + arg="write_args_names", + value=ast.List( + elts=[ast.Constant(value=arg) for arg in write_args], + ctx=ast.Load(), + ), + ), + ast.keyword( + arg="then_block", value=ast.Name(id=then_block_name, ctx=ast.Load()) + ), + ] + + # Handle different cases + if not write_args and node.orelse == []: + # No write_args case - only then_block needed + execute_call = ast.copy_location( + ast.Call( + func=self._create_module_attribute( + self.IF_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset + ), + args=[], + keywords=execute_keywords, + ), + node, + ) + func_body = [then_block, ast.Return(value=execute_call)] + else: + # Create else block based on node.orelse + if node.orelse: + if len(node.orelse) == 1 and isinstance(node.orelse[0], ast.If): + # Handle elif case + elif_node = node.orelse[0] + nested_if_name = elif_region_name + # Recursion for nested elif + nested_if = self.create_if_function( + nested_if_name, elif_node, write_args, 
def create_while_function(self, func_name, node, write_args, full_write_args_count):
    """Create a while function that looks like:

    @while_selector(pred, write_args=[])
    def while_region(pred, write_args):
        def while_before_block(*write_args):
            # Note that during eval of pred can possibly alter yield_args
            return *pred, write_args
        def while_after_block(*write_args):
            ...loop_body_transformed...
            return write_args
        return self.while_executor(pred, write_args,
            while_before_block, while_after_block, constexpr)
    write_args = while_region(pred, write_args)

    Which will later be executed as pseudo-code:

    # Dynamic mode:
    scf.WhileOp(types(write_args), write_args)
    with InsertionPoint(before_block):
        cond, write_args = while_before_block(*write_args)
        scf.ConditionOp(cond, write_args)
    with InsertionPoint(after_block):
        write_args = while_after_block(write_args)
        scf.YieldOp(write_args)
    return while_op.results_

    # Const mode:
    cond, write_args = while_before_block(write_args)
    while pred:
        write_args = body_block(write_args)
        cond, write_args = while_before_block(write_args)
    return write_args

    Args:
        func_name: Name given to the generated outer ``FunctionDef``.
        node: The loop node being rewritten; its ``test`` and ``body`` are
            visited, and it supplies source locations for generated nodes.
        write_args: Names of the variables threaded through the loop; used as
            parameters of every generated function and as the returned list.
        full_write_args_count: Forwarded unchanged to the executor call as a
            constant keyword argument.

    Returns:
        The decorated ``ast.FunctionDef`` for the while region. Calling this
        method also advances ``self.counter`` by one.
    """
    # Transform the loop condition once; the same node object is used both in
    # the decorator's ``pred`` keyword and in while_before_block's return.
    test_expr = self.visit(node.test)
    pred_name = self.make_func_param_name("pred", write_args)

    # Section: decorator construction
    decorator_keywords = [
        ast.keyword(arg="pred", value=test_expr),
        ast.keyword(
            arg="write_args",
            value=self.generate_get_locals_or_none_call(write_args),
        ),
    ]
    decorator = ast.copy_location(
        ast.Call(
            func=self._create_module_attribute(
                self.DECORATOR_WHILE_STATEMENT,
                lineno=node.lineno,
                col_offset=node.col_offset,
            ),
            args=[],
            keywords=decorator_keywords,
        ),
        node,
    )

    # Section: Shared initialization for before and after blocks
    # Both inner blocks share one counter value, so their names pair up.
    while_before_block_name = f"while_before_block_{self.counter}"
    while_after_block_name = f"while_after_block_{self.counter}"
    self.counter += 1
    block_args_args = [ast.arg(arg=var, annotation=None) for var in write_args]
    block_args = ast.arguments(
        posonlyargs=[],
        args=block_args_args,
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[],
    )

    # List of the carried variables, returned by both inner blocks.
    yield_args_ast_name_list = ast.List(
        elts=[ast.Name(id=var, ctx=ast.Load()) for var in write_args],
        ctx=ast.Load(),
    )

    # Section: while_before_block FunctionDef, which contains condition
    # Returns ``[<condition>, [<write_args>...]]``.
    while_before_return_list = ast.List(
        elts=[test_expr, yield_args_ast_name_list],
        ctx=ast.Load(),
    )
    while_before_stmts = [ast.Return(value=while_before_return_list)]
    while_before_block = ast.copy_location(
        ast.FunctionDef(
            name=while_before_block_name,
            args=block_args,
            body=while_before_stmts,
            decorator_list=[],
        ),
        test_expr,
    )

    # Section: while_after_block FunctionDef, which contains loop body
    while_after_stmts = []
    for stmt in node.body:
        transformed_stmt = self.visit(stmt)  # Recursively visit inner statements
        # A visitor may expand one statement into several.
        if isinstance(transformed_stmt, list):
            while_after_stmts.extend(transformed_stmt)
        else:
            while_after_stmts.append(transformed_stmt)
    while_after_stmts.append(ast.Return(value=yield_args_ast_name_list))

    while_after_block = ast.copy_location(
        ast.FunctionDef(
            name=while_after_block_name,
            args=block_args,
            body=while_after_stmts,
            decorator_list=[],
        ),
        node,
    )

    # Section: Execute via executor
    execute_keywords = [
        ast.keyword(arg="pred", value=ast.Name(id=pred_name, ctx=ast.Load())),
        ast.keyword(
            arg="write_args",
            value=ast.List(
                elts=[ast.Name(id=arg, ctx=ast.Load()) for arg in write_args],
                ctx=ast.Load(),
            ),
        ),
        ast.keyword(
            arg="full_write_args_count",
            value=ast.Constant(value=full_write_args_count),
        ),
        ast.keyword(
            arg="while_before_block",
            value=ast.Name(id=while_before_block_name, ctx=ast.Load()),
        ),
        ast.keyword(
            arg="while_after_block",
            value=ast.Name(id=while_after_block_name, ctx=ast.Load()),
        ),
        ast.keyword(
            arg="write_args_names",
            value=ast.List(
                elts=[ast.Constant(value=arg) for arg in write_args],
                ctx=ast.Load(),
            ),
        ),
    ]

    # NOTE(review): unlike the other generated nodes, this call is not passed
    # through ast.copy_location here - presumably locations are filled in later;
    # confirm against the caller.
    execute_call = ast.Call(
        func=self._create_module_attribute(
            self.WHILE_EXECUTOR, lineno=node.lineno, col_offset=node.col_offset
        ),
        args=[],
        keywords=execute_keywords,
    )

    # Putting everything together, FunctionDef for while_region
    # Outer parameters: the predicate name first, then the carried variables.
    func_args_args = [ast.arg(arg=pred_name, annotation=None)]
    func_args_args += [ast.arg(arg=var, annotation=None) for var in write_args]
    func_args = ast.arguments(
        posonlyargs=[],
        args=func_args_args,
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[],
    )

    return ast.copy_location(
        ast.FunctionDef(
            name=func_name,
            args=func_args,
            body=[
                while_before_block,
                while_after_block,
                ast.Return(value=execute_call),
            ],
            decorator_list=[decorator],
        ),
        node,
    )
def get_current_user():
    """Return the current user's name.

    The ``USER`` and then ``USERNAME`` environment variables are consulted
    first; the first non-empty value wins. If neither yields a name, the
    password database entry for the process's uid is used instead.
    """
    for variable in ("USER", "USERNAME"):
        name = os.getenv(variable)
        if name:
            return name
    # Fallback for Unix-like systems
    return pwd.getpwuid(os.getuid()).pw_name
def make_unique_filename(fpath: Path, new_ext: str = None) -> Path:
    """Generate a unique filename with an optional new extension.

    A 16-hex-digit tag, derived from the path, the current time and a random
    number, is appended to the file's stem: ``dir/name.ext`` becomes
    ``dir/name_<tag>.ext``.

    :param fpath: Path whose final component is used as the template.
    :param new_ext: Replacement extension including the leading dot
        (e.g. ``".bc"``). When ``None`` or empty, ``fpath``'s own suffix is kept.
    :return: A sibling path of ``fpath`` carrying the tag.
    :raises ValueError: If ``new_ext`` is not a valid suffix, or ``fpath`` has
        an empty name (both raised by ``pathlib``).
    """
    random_part = random.randint(0, 999999)
    timestamp = time.time()
    hash_input = f"{fpath}_{timestamp}_{random_part}".encode()
    # The digest only needs to be unlikely to collide; it is not a security
    # measure. Shorter hash for readability.
    hash_code = hashlib.md5(hash_input).hexdigest()[:16]
    stem_with_hash = f"{fpath.stem}_{hash_code}"
    # Attach a throwaway suffix before calling with_suffix(): for a stem that
    # itself contains dots ("kernel.v2"), with_suffix() would otherwise treat
    # ".v2_<tag>" as the suffix and replace it, discarding the tag.
    placeholder = fpath.with_name(f"{stem_with_hash}.tmp")
    return placeholder.with_suffix(new_ext or fpath.suffix)
def load_cache_from_path(dsl_name, cache_limit, path=default_generated_ir_path):
    """Load cached IR modules from a directory into a fresh JIT cache.

    Only entries whose name starts with ``dsl_name.lower()`` and contains
    ``".mlir"`` are loaded (as bytecode). Loading is best-effort: any failure
    is printed and an empty cache is returned.

    :param dsl_name: DSL name; its lower-cased form is the filename prefix.
    :param cache_limit: Maximum number of cache entries to load (converted
        with ``int``).
    :param path: Directory to read from.
    :return: Dict mapping function name to an executor whose ``ir_module``
        attribute holds the parsed module; empty if ``path`` is not a
        directory or loading failed.
    """
    # A missing path, or one that is not a directory, simply means "no cache".
    if not os.path.isdir(path):
        return dict()
    jit_cache = dict()
    try:
        limit = int(cache_limit)
        prefix = dsl_name.lower()
        for file in os.listdir(path):
            # Count entries actually loaded, so unrelated files in the
            # directory do not use up the limit (mirrors dump_cache_to_path,
            # where the limit bounds the number of entries written).
            if len(jit_cache) >= limit:
                break
            # identify dsl prefix
            if not file.startswith(prefix) or ".mlir" not in file:
                continue
            func_name, ir_module = load_ir(os.path.join(path, file), asBytecode=True)
            jit_cache = check_func_name(jit_cache, func_name)
            jit_cache[func_name].ir_module = ir_module
    except Exception as e:
        # Deliberately broad: a broken cache must never break compilation.
        print(f"{dsl_name} failed with loading generated IR cache.", e)
        jit_cache = dict()
    return jit_cache
class DSLBaseError(Exception):
    """
    Root of the DSL exception hierarchy.

    Besides the message, an instance can carry optional metadata (source line,
    file name, code snippet, error code, free-form context, suggestions and an
    originating exception). All of it is stored on the instance and rendered
    into the text passed to ``Exception.__init__``.
    """

    def __init__(
        self,
        message: str,
        line: Optional[int] = None,
        snippet: Optional[str] = None,
        filename: Optional[str] = None,
        error_code: Optional[Union[str, int]] = None,
        context: Optional[Union[Dict[str, Any], str]] = None,
        suggestion: Optional[str] = None,
        cause: Optional[BaseException] = None,
    ) -> None:
        # Attributes must exist before _format_message() reads them.
        self.message = message
        self.line = line
        self.filename = filename
        self.snippet = snippet
        self.error_code = error_code
        self.context = context
        self.suggestion = suggestion
        self.cause = cause

        rendered = self._format_message()
        super().__init__(rendered)

    def _format_message(self):
        """
        Render the message and whatever metadata is present, one item per line.

        ``error_code``, ``line`` and ``filename`` are shown whenever they are
        not ``None``; the remaining fields only when truthy. A dict ``context``
        is expanded key by key, and a list/tuple ``suggestion`` item by item.
        Subclasses may override this to change the layout.
        """
        report = [f"{self.__class__.__name__}: {self.message}"]
        emit = report.append

        if self.error_code is not None:
            emit(f"{Colors.BOLD}Error Code:{Colors.RESET} {self.error_code}\n")
        if self.line is not None:
            emit(f" Line: {self.line}")
        if self.filename is not None:
            emit(f" File: {self.filename}")
        if self.snippet:
            emit(f" Snippet: \n {self.snippet}")
        if self.cause:
            emit(f" Caused exception: {self.cause}")

        context = self.context
        if context and isinstance(context, dict):
            emit(f"{Colors.BLUE}🔍 Additional Context:{Colors.RESET}\n")
            report.extend(f" {key}: {value}" for key, value in context.items())
        elif context:
            emit(f"{Colors.BLUE}🔍 Additional Context:{Colors.RESET} {context}")

        hints = self.suggestion
        if hints:
            emit(f"{Colors.GREEN}💡 Suggestions:{Colors.RESET}")
            if isinstance(hints, (list, tuple)):
                report.extend(
                    f" {Colors.GREEN}{hint}{Colors.RESET}" for hint in hints
                )
            else:
                emit(f" {hints}")

        return "\n".join(report)
def _get_friendly_cuda_error_message(error_code, error_name):
    """Get a user-friendly error message for common CUDA errors.

    :param error_code: CUDA error code. Not used when building the message;
        accepted so callers can pass code and name together.
    :param error_name: CUDA error name as ``str`` or ``bytes``. A ``bytes``
        value is decoded as UTF-8, and a string of the form ``"b'...'"`` has
        that wrapper stripped.
    :return: ``(message, debug_info, suggestions)``. ``message`` is the
        headline for the error (a generic one for unknown names),
        ``debug_info`` is a multi-line block with environment and GPU details,
        and ``suggestions`` is a tuple of hint strings, or ``""`` when no
        hints are registered for ``error_name``.
    """
    # Avoid circular dependency
    from .runtime.cuda import get_device_info

    # Strip the byte string markers if present
    if isinstance(error_name, bytes):
        error_name = error_name.decode("utf-8")
    elif (
        isinstance(error_name, str)
        and error_name.startswith("b'")
        and error_name.endswith("'")
    ):
        error_name = error_name[2:-1]

    # Add target architecture info
    target_arch = os.getenv("CUTE_DSL_ARCH", "unknown")

    error_messages = {
        "CUDA_ERROR_INVALID_SOURCE": (
            f"{Colors.RED}❌ Failed to load CUDA kernel - likely architecture mismatch.{Colors.RESET}\n\n"
        ),
        "CUDA_ERROR_NO_BINARY_FOR_GPU": (
            f"{Colors.RED}❌ CUDA kernel not compatible with your GPU.{Colors.RESET}\n\n"
        ),
        "CUDA_ERROR_OUT_OF_MEMORY": (
            f"{Colors.RED}💾 CUDA out of memory error.{Colors.RESET}\n\n"
        ),
        "CUDA_ERROR_INVALID_DEVICE": (
            f"{Colors.RED}❌ Invalid CUDA device.{Colors.RESET}\n\n"
        ),
        "CUDA_ERROR_NOT_INITIALIZED": (
            f"{Colors.RED}❌ CUDA context not initialized.{Colors.RESET}\n\n"
        ),
        "CUDA_ERROR_INVALID_VALUE": (
            f"{Colors.RED}⚠️ Invalid parameter passed to CUDA operation.{Colors.RESET}\n\n"
            f"{Colors.YELLOW}This is likely a bug - please report it with:{Colors.RESET}"
        ),
    }

    # Plain literals: none of these hints interpolate anything.
    error_suggestions = {
        "CUDA_ERROR_INVALID_SOURCE": (
            "1. Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            "2. Clear the compilation cache and regenerate the kernel",
            "3. Check CUDA toolkit installation",
        ),
        "CUDA_ERROR_NO_BINARY_FOR_GPU": (
            "Set env CUTE_DSL_ARCH to match your GPU architecture",
        ),
        "CUDA_ERROR_OUT_OF_MEMORY": (
            "1. Reduce batch size",
            "2. Reduce model size",
            "3. Free unused GPU memory",
        ),
        "CUDA_ERROR_INVALID_DEVICE": (
            "1. Check if CUDA device is properly initialized",
            "2. Verify GPU is detected: nvidia-smi",
            "3. Check CUDA_VISIBLE_DEVICES environment variable",
        ),
        "CUDA_ERROR_NOT_INITIALIZED": (
            "1. Check CUDA driver installation",
            "2. call `cuda.cuInit(0)` before any other CUDA operation",
            "3. Run nvidia-smi to confirm GPU status",
        ),
        "CUDA_ERROR_INVALID_VALUE": (
            "1. Your GPU model",
            "2. SM ARCH setting",
            "3. Steps to reproduce",
        ),
    }

    message = error_messages.get(
        error_name, f"{Colors.RED}Unknown CUDA error{Colors.RESET}"
    )

    # Add debug information
    debug_info = f"\n- {Colors.BOLD}Error name: {error_name}\n"
    debug_info += f"- CUDA_TOOLKIT_PATH: {os.getenv('CUDA_TOOLKIT_PATH', 'not set')}\n"
    debug_info += (
        f"- Target SM ARCH: {os.getenv('CUTE_DSL_ARCH', 'not set')}{Colors.RESET}\n"
    )

    try:
        # Get GPU information using CUDA Python API
        debug_info += f"\n{Colors.BLUE}📊 GPU Information:{Colors.RESET}\n"
        gpu_info = get_device_info()
        debug_info += gpu_info.pretty_str()

        if target_arch and gpu_info.compatible_archs:
            debug_info += f"\n{Colors.BOLD}Compatibility Check:{Colors.RESET}\n"

            if target_arch not in gpu_info.compatible_archs:
                debug_info += (
                    f"{Colors.RED}❌ Error: Target SM ARCH {target_arch} is not compatible\n"
                    f"💡 Please use one of SM ARCHs: "
                    f"{Colors.GREEN}{', '.join(gpu_info.compatible_archs or [])}{Colors.RESET}\n"
                )
            elif target_arch != gpu_info.sm_arch:
                debug_info += (
                    f"{Colors.YELLOW}⚠️ Warning: Using compatible but non-optimal architecture\n"
                    f"• Current: {target_arch}\n"
                    f"• Recommended: {Colors.GREEN}{gpu_info.sm_arch}{Colors.RESET} (native)\n"
                )
            else:
                debug_info += f"{Colors.GREEN}✓ Using optimal architecture: {gpu_info.sm_arch}{Colors.RESET}\n"

    except Exception as e:
        # Best effort: failing to query the GPU must not hide the CUDA error
        # that is being reported.
        debug_info += (
            f"\n{Colors.YELLOW}ℹ️ Could not retrieve GPU info: {str(e)}{Colors.RESET}"
        )

    return message, debug_info, error_suggestions.get(error_name, "")
class DSLAstPreprocessorError(DSLBaseError):
    """
    Error raised for failures while the DSL preprocesses or visits an AST.

    Adds nothing to DSLBaseError; it exists so callers can catch this failure
    category on its own. Override ``_format_message`` here if AST-specific
    details should be emphasized.
    """
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides a class that compiles generated IR using MLIR's PassManager +and executes it using MLIR's ExecutionEngine. + +""" + +from typing import Sequence, Optional, Tuple +import os +import sys +import inspect +import argparse +from .common import DSLRuntimeError +from .utils.logger import log + +_SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(_SCRIPT_PATH) + +from .._mlir import ir + + +# ============================================================================= +# Compiler Class +# ============================================================================= + + +class CompilationError(RuntimeError): + """Custom error class for compilation failures""" + + # Add ANSI color codes + RED = "\033[91m" + YELLOW = "\033[93m" + BLUE = "\033[94m" + GREEN = "\033[92m" + BOLD = "\033[1m" + RESET = "\033[0m" + + def __init__( + self, + message: str, + nvvm_error: Optional[str] = None, + ir_context: Optional[str] = None, + cuda_toolkit: Optional[str] = None, + arch: Optional[str] = None, + ): + self.nvvm_error = nvvm_error + self.ir_context = ir_context + self.cuda_toolkit = cuda_toolkit + self.arch = arch + # Call parent with formatted error to avoid showing class name + super().__init__("") # Empty string to avoid class name + # Store formatted error for str() representation + self._formatted_error = self._format_error() + + def __str__(self) -> str: + """Override string representation to avoid showing class name""" + return self._formatted_error + + def __repr__(self) -> str: + """Override repr 
class Compiler:
    """Compiler class for compiling and building MLIR modules.

    ``passmanager`` must expose ``PassManager.parse(pipeline)`` and
    ``execution_engine`` must expose ``ExecutionEngine(...)``; both are stored
    as given.
    """

    def __init__(self, passmanager, execution_engine):
        self.passmanager = passmanager
        self.execution_engine = execution_engine

    def __call__(self, module, *args, **kwargs):
        """Convenience application method: forwards everything to :meth:`compile`.

        ``compile`` requires at least a ``pipeline`` argument, so it has to be
        passed through here; calling with the module alone cannot succeed.
        """
        return self.compile(module, *args, **kwargs)

    def _process_error(self, error_msg: str) -> Tuple[Optional[str], Optional[str]]:
        """Process error message to extract NVVM error and IR context.

        :param error_msg: Text of the exception raised by the pass manager.
        :return: ``(nvvm_error, ir_msg)``. ``nvvm_error`` is ``None`` unless the
            text mentions ``NVVM_ERROR``; it is then the part after
            ``"libNVVM extra log:"`` if present, otherwise the whole text.
            ``ir_msg`` is the IR excerpt following ``"see current operation:"``
            (shortened to its first and last five lines when longer than ten),
            or ``""`` when unavailable.
        """
        nvvm_error = None
        ir_msg = ""

        if "NVVM_ERROR" in error_msg:
            # Extract the specific NVVM error
            nvvm_error = (
                error_msg.split("libNVVM extra log:")[1].strip()
                if "libNVVM extra log:" in error_msg
                else error_msg
            )

            # Extract IR context
            if "see current operation:" in error_msg:
                # Get the IR section
                ir_section = error_msg.split("see current operation:")[1].strip()
                # Remove duplicate IR section
                ir_section = ir_section.split("error: unknown: Failed translating")[
                    0
                ].strip()

                # Get first few lines and last few lines of the IR
                ir_lines = ir_section.split("\n")
                if len(ir_lines) > 10:
                    ir_msg = "\n".join(ir_lines[:5] + [" ..."] + ir_lines[-5:])
                else:
                    ir_msg = ir_section

        return nvvm_error, ir_msg

    def compile(
        self,
        module,
        pipeline: str,
        cuda_toolkit: str = "",
        arch: str = "",
        enable_verifier=False,
    ):
        """Compiles the module by invoking the pipeline.

        :param module: Module whose ``operation`` is run through the passes.
        :param pipeline: Textual pass pipeline parsed by the pass manager.
        :param cuda_toolkit: Toolkit path, used only in the error report.
        :param arch: Target architecture, used only in the error report.
        :param enable_verifier: Passed to the pass manager's ``enable_verifier``.
        :raises CompilationError: If the failure text contains an NVVM error.
        :raises Exception: Any other failure is re-raised unchanged.
        """
        try:
            pm = self.passmanager.PassManager.parse(pipeline)
            pm.enable_verifier(enable_verifier)
            pm.run(module.operation)
        except Exception as e:
            error_msg = str(e)
            nvvm_error, ir_msg = self._process_error(error_msg)

            if nvvm_error:
                raise CompilationError(
                    error_msg,
                    nvvm_error=nvvm_error,
                    ir_context=ir_msg,
                    cuda_toolkit=cuda_toolkit,
                    arch=arch,
                ) from e
            # Not an NVVM failure: propagate the original exception as-is.
            raise

    def jit(self, module, opt_level: int = 2, shared_libs: Sequence[str] = ()):
        """Wraps the module in a JIT execution engine."""
        return self.execution_engine.ExecutionEngine(
            module, opt_level=opt_level, shared_libs=shared_libs
        )

    def compile_and_jit(
        self,
        module,
        pipeline: str,
        shared_libs: Sequence[str] = (),
        opt_level: int = 2,
        cuda_toolkit: str = "",
        arch: str = "",
    ):
        """Compiles and jits the module.

        Runs :meth:`compile` (verifier left at its default) and then returns
        the execution engine produced by :meth:`jit`.
        """
        self.compile(
            module,
            pipeline,
            cuda_toolkit,
            arch,
        )
        return self.jit(module, opt_level, shared_libs)
def compile(func, *args, **kwargs):
    """
    This function is used to compile a `cute.jit` decorated function.
    It will process the compile options and input parameters, do explicit compilation and return the jit executor.

    :param func: The function to compile. It can be a regular function, a method or a class instance.
    :param args: The arguments to pass to the function.
    :param kwargs: The keyword arguments to pass to the function. It can contain `options` like
                   `opt_level` to control the compilation flags.

    :return: The jit executor.

    :raises: DSLRuntimeError if the function is not decorated with `cute.jit` or is not callable.
    """
    if func is None:
        raise DSLRuntimeError("Function is not set or invalid.")

    if not callable(func):
        raise DSLRuntimeError("Object is not callable.")

    # Always compile explicitly, bypassing the cache.
    kwargs["compile_only"] = True
    kwargs["no_cache"] = True

    if inspect.isfunction(func):
        # regular function
        pass
    elif inspect.ismethod(func):
        # if it's a method, add the instance to the first argument
        args = [func.__self__] + list(args)
        func = func.__func__
    else:
        # Callable object: use the plain function behind its bound __call__.
        # Builtins, classes and similar callables have no such function; they
        # are rejected here instead of failing with a bare AttributeError.
        call_impl = getattr(getattr(func, "__call__", None), "__func__", None)
        if call_impl is None:
            # The offending object goes into the message itself: the second
            # positional parameter of DSLRuntimeError is the source line.
            raise DSLRuntimeError(
                "Invalid function type, only function, method and module are "
                f"supported, but got {func}"
            )
        args = [func] + list(args)
        # Get the actual function from the class definition
        func = call_impl

    # If it's a wrapped function created by jit decorator, get the original function
    if hasattr(func, "__wrapped__"):
        func = func.__wrapped__

    if not hasattr(func, "_dsl_object"):
        raise DSLRuntimeError("Function is not decorated with jit decorator.")

    # process compile options, extract the options and remove them from the kwargs
    options = kwargs.pop("options", "")
    func._dsl_object.compile_options = CompileOptions(options)
    fcn_ptr = func._dsl_object._preprocess_and_execute(func)
    return func._dsl_object._func(fcn_ptr, *args, **kwargs)
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides a main DSL class for any Dialect. +The DSL should be inherited as a new class, and its initialization requires dialects. +It handles most of the mechanics for the DSL in an agnostic way, +for example, it can handle various dialect-specific tasks. +""" + + +# Standard library imports +from dataclasses import dataclass, field +import atexit +import os +import io +import sys +import errno +import ctypes +import re +import inspect +import argparse +import hashlib +from functools import lru_cache, wraps +from collections import namedtuple +from abc import ABC, abstractmethod +from typing import Any, Union, Tuple, get_origin, get_args, List +from types import FunctionType, SimpleNamespace +import warnings + +from . 
def _numpy_type_to_mlir_type(dtype):
    """
    Map a NumPy scalar type (or one of the narrow float types) to an MLIR type.

    ``dtype`` is compared by equality against each supported type in turn and
    the matching ``T.*`` constructor is called only for the first match.

    :param dtype: The type object to translate.
    :return: The MLIR type produced by the matching ``T.*`` constructor.
    :raises AssertionError: If ``dtype`` matches none of the supported types.
    """
    if dtype == np.float64:
        return T.f64()
    if dtype == np.float16:
        return T.f16()
    if dtype == np.float32:
        return T.f32()
    if dtype == np.int64:
        return T.i64()
    if dtype == np.int32:
        return T.i32()
    if dtype == np.int16:
        return T.i16()
    if dtype == np.int8:
        return T.i8()
    if dtype == np.uint64:
        return T.ui64()
    if dtype == np.uint32:
        return T.ui32()
    if dtype == np.uint16:
        return T.ui16()
    if dtype == np.uint8:
        return T.ui8()
    if dtype == np.bool_:
        return T.bool()
    # NOTE(review): the f8/f6/f4 names are not defined in this function;
    # presumably they come from one of the module's star imports - confirm.
    if dtype == f8E5M2:
        return T.f8E5M2()
    if dtype == f8E4M3FN:
        return T.f8E4M3FN()
    if dtype == f8E8M0FNU:
        return T.f8E8M0FNU()
    if dtype == f6E3M2FN:
        return T.f6E3M2FN()
    if dtype == f6E2M3FN:
        return T.f6E2M3FN()
    if dtype == f4E2M1FN:
        return T.f4E2M1FN()
    # Report the offending dtype (the builtin `type` would always print
    # "<class 'type'>"). Raised explicitly so the check survives `python -O`.
    raise AssertionError(f"Unknown type {dtype}")
obj.__extract_mlir_values__() + elif isinstance(obj, (tuple, list)): + res = sum((extract_mlir_values(x) for x in obj), []) + elif isinstance(obj, SimpleNamespace): + res = [] + for k, v in obj.__dict__.items(): + res.extend(extract_mlir_values(v)) + # Can't call is_dynamic_expression as _is_dynamic_expression depends on extract_mlir_values + elif isinstance(obj, set): + raise DSLRuntimeError( + "Sets are not supported in extract_mlir_values to ensure order preservation", + context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", + suggestion="Consider using a list or tuple instead", + ) + elif isinstance(obj, ir.Value): + res = [obj] + elif isinstance(obj, ir.BlockArgumentList): + res = list(obj) # type: ignore + + return res + + +def new_from_mlir_values(obj, values): + """ + Create a new python object by populating containing MLIR values with list of new values + """ + if hasattr(obj, "__new_from_mlir_values__"): + return obj.__new_from_mlir_values__(values) + elif isinstance(obj, (tuple, list)): + res = [] + for x in obj: + n_items = len(get_mlir_types(x)) + res.append(new_from_mlir_values(x, values[:n_items])) + values = values[n_items:] + obj_ty = type(obj) + return obj_ty(res) + elif isinstance(obj, SimpleNamespace): + res = SimpleNamespace() + for k, v in obj.__dict__.items(): + n_items = len(get_mlir_types(v)) + res.__dict__[k] = new_from_mlir_values(v, values[:n_items]) + values = values[n_items:] + return res + elif isinstance(obj, set): + raise DSLRuntimeError( + "Sets are not supported in new_from_mlir_values to ensure order preservation", + context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", + suggestion="Consider using a list or tuple instead", + ) + elif is_dynamic_expression(obj): + + if len(values) == 0: + return obj + + assert len(values) == 1 + return values[0] + else: + assert len(values) == 0, f"{obj} expects 0 values, but got {values}" + 
class DSLCallable:
    """
    Single-use wrapper around a callable used within the DSL.

    DSLCallable wraps a function and exposes introspection helpers (name and
    signature). The wrapped function can be invoked once; the reference to it
    is dropped as soon as that invocation finishes, whether it returns or
    raises, so that any later use fails loudly.

    Attributes:
        func (callable): The wrapped function, or ``None`` once it was called.

    Methods:
        __call__(*args, **kwargs): Calls the wrapped function and clears it.
    """

    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **kwargs):
        """
        Invoke the wrapped function once and return its result.

        :raises AssertionError: If the wrapper was already called.
        """
        target = self.__func__
        try:
            return target(*args, **kwargs)
        finally:
            # Clear the reference even when the call raises; otherwise a failed
            # invocation would leave the wrapper callable a second time.
            self.func = None

    @property
    def __func__(self):
        """The wrapped function; raises ``AssertionError`` after it was called."""
        # Raised explicitly (not via `assert`) so the guard survives `python -O`.
        if self.func is None:
            raise AssertionError("DSLCallable is already called")
        return self.func

    @property
    def __signature__(self):
        """Signature of the wrapped function, as reported by ``inspect.signature``."""
        return inspect.signature(self.__func__)

    @property
    def __name__(self):
        """``__name__`` of the wrapped function."""
        return self.__func__.__name__
It reads + environment variables using `EnvironmentVarManager` and configures + a logger with settings from the environment. If environment warnings + are detected, they are escalated to errors to ensure strict handling. + """ + # Enforcing initialization of instance variables + if not all([name, compiler_provider, pass_sm_arch_name]): + raise DSLRuntimeError( + "All required parameters must be provided and non-empty" + ) + + self.name = name + self.compiler_provider = compiler_provider + self.pass_sm_arch_name = pass_sm_arch_name + self.frame = None + self.no_cache = False + self.device_compilation_only = device_compilation_only + self.num_kernels = 0 + # Read environment variables + self.envar = EnvironmentVarManager(self.name) + self.enable_preprocessor = preprocess + # This cache uses hash of original ir and env as key, allows dump/load to/from file. Enabled by default + self.jit_cache = ( + dict() + if self.envar.disable_file_caching + else load_cache_from_path(self.name, self.envar.file_caching_capacity) + ) + self.host_jit_decorator_name = f"@{BaseDSL.jit.__name__}" + self.device_jit_decorator_name = f"@{BaseDSL.kernel.__name__}" + + # set warning + if not self.envar.enable_optimization_warnings: + # By default, optimization warnings are disabled + warnings.filterwarnings("ignore", category=DSLOptimizationWarning) + if self.envar.warnings_as_errors: + warnings.filterwarnings("error") + if self.envar.warnings_ignore: + warnings.filterwarnings("ignore") + + # Initialize logger + if self.envar.log_to_console == False and self.envar.jitTimeProfiling: + self.envar.log_to_console = True + self.envar.log_level = 20 # info level + setup_log( + self.name, + self.envar.log_to_console, + self.envar.log_to_file, + f"{self.name}.log", + self.envar.log_level, + ) + + # kernel symbols are temporary symbol string variables, their values are valid until the compilation is done. 
+ self.kernel_symbols = [] + # used to generate unique name for gpu.launch + self.launch_inner_count = 0 + # initialize default compile options + self.compile_options = CompileOptions() + + if preprocess: + self.preprocessor = DSLPreprocessor(dsl_package_name) + log().info(f"Initializing {name} DSL") + log().debug(f"Logger initialized for {self.name}") + + # Hook excepthook + if self.envar.filterStacktrace: + origin_excepthook = sys.excepthook + module_dir = walk_to_top_module(os.path.dirname(os.path.abspath(__file__))) + + def excepthook(excep_type, value, traceback): + filter_exception(value, module_dir) + if hasattr(value, "__traceback__"): + origin_excepthook(excep_type, value, value.__traceback__) + else: + origin_excepthook( + excep_type, value, filter_stackframe(traceback, module_dir) + ) + + sys.excepthook = excepthook + + # Restore original excepthook + def restore_excepthook(hook): + sys.excepthook = hook + + atexit.register(restore_excepthook, origin_excepthook) + + def dump_cache(self): + if not self.envar.disable_file_caching: + dump_cache_to_path( + self.name, self.jit_cache, self.envar.file_caching_capacity + ) + + @lru_cache(maxsize=1) + def print_warning_once(self, message): + log().warning(f"Warning: {message}") + warnings.warn(message, UserWarning) + + def print_warning(self, message): + log().warning(f"Warning: {message}") + warnings.warn(message, UserWarning) + + @classmethod + @lru_cache(maxsize=1) + def _get_dsl(cls): + # Instantiate the DSL Class once + main_dsl = cls() + if not main_dsl.no_cache: + # register atexit callback + atexit.register(main_dsl.dump_cache) + return main_dsl + + @staticmethod + def _can_preprocess(**dkwargs): + """ + Check if AST transformation is enabled or not for `jit` and `kernel` decorators. 
+ """ + return dkwargs.pop("preprocess", True) + + @staticmethod + def _get_original_function(fcn_ptr, name): + """ + Get the original function from the decorated function + """ + while fcn_ptr.__name__ != name: + # If the function is wrapped with functools, get from __wrapped__ + if hasattr(fcn_ptr, "__wrapped__"): + fcn_ptr = fcn_ptr.__wrapped__ + # If the function is wrapped manually, it's the first in clousure + elif callable(fcn_ptr.__closure__[0].cell_contents): + fcn_ptr = fcn_ptr.__closure__[0].cell_contents + else: + raise DSLRuntimeError( + f"Cannot find the original function {name} in the closure chain" + ) + return fcn_ptr + + @staticmethod + def _preprocess_and_execute(func): + """ + Run ast transformation and return the materialized function pointer + """ + if hasattr(func, "_transformed_ast"): + # If the function ptr is already materialized, use the existing one + func._dsl_object.frame = func._decorator_frame + if func._transformed_ast is None: + func._transformed_ast = func._dsl_object.run_preprocessor(func) + if func._transformed_ast is None: + del func._transformed_ast + func._dsl_object.frame = None + return func + + fcn_ptr = func._dsl_object.get_function_ptr(func) + # If the function is decorated, de-decorate it + fcn_ptr = BaseDSL._get_original_function(fcn_ptr, func.__name__) + func._dsl_object.frame = None + return DSLCallable(fcn_ptr) + return func + + def jit_runner(self, executor, frame, *dargs, **dkwargs): + """ + Decorator to mark a function for JIT compilation. 
+ """ + log().info("jit_runner") + + def jit_runner_decorator(func): + func._dsl_object = self + # Run preprocessor that alters AST + if self.enable_preprocessor and BaseDSL._can_preprocess(**dkwargs): + # For an annotated function, add some DSL attributes + # When materializing the AST, we need decorator's frame + func._decorator_frame = frame + # No transformed ast at this point + func._transformed_ast = None + + @wraps(func) + def jit_wrapper(*args, **kwargs): + func_ptr = BaseDSL._preprocess_and_execute(func) + return executor(func_ptr, *args, **kwargs) + + return jit_wrapper + + if len(dargs) == 1 and callable(dargs[0]): + return jit_runner_decorator(dargs[0]) + else: + return jit_runner_decorator + + @classmethod + def jit(cls, *dargs, **dkwargs): + """ + Decorator to mark a function for JIT compilation for Host code. + """ + frame = inspect.currentframe().f_back + # Instantiate the DSL Class + main_dsl = cls._get_dsl() + return main_dsl.jit_runner(main_dsl._func, frame, *dargs, **dkwargs) + + @classmethod + def kernel(cls, *dargs, **dkwargs): + """ + Decorator to mark a function for JIT compilation for GPU. + """ + frame = inspect.currentframe().f_back + # Instantiate the DSL Class + main_dsl = cls._get_dsl() + return main_dsl.jit_runner(main_dsl._kernel_helper, frame, *dargs, **dkwargs) + + @abstractmethod + def _kernel_helper(self, func, *args, **kwargs): + """ + Helper function to handle kernel generation logic + """ + pass + + @abstractmethod + def _build_gpu_module(self, attrs): + """ + Build the module op that contains the kernels. + """ + pass + + @abstractmethod + def _get_pipeline(self, pipeline): + """ + Get the pipeline from the other configuration options. 
+ """ + if pipeline != None: + return pipeline + return None + + @staticmethod + def log_additions(func_type, operands=None, types=None, arg_attrs=None): + if operands is not None and operands != []: + log().debug( + f"Added {func_type} operands: [%s]", ", ".join(map(str, operands)) + ) + if types is not None: + log().debug( + f"Added {func_type} arg_types: [%s]", ", ".join(map(str, types)) + ) + if arg_attrs is not None: + log().debug( + f"Added {func_type} arg_attrs: [%s]", ", ".join(map(str, arg_attrs)) + ) + + def mangle_name(self, function_name, args, args_spec: inspect.FullArgSpec): + """Does simple name mangling""" + + for spec_arg, arg in zip(args_spec.args, args): + spec_ty = args_spec.annotations.get(spec_arg, None) + if spec_ty != None: + if issubclass(type(spec_ty), (t.IRValue, t.IRVariadic)): + continue + if isinstance(spec_ty, (ir.Type, ir.Value)): + continue + if isinstance(arg, (ir.Type, ir.Value, ir.OpResult)): + continue + if isinstance(type(arg), (ir.Type, ir.Value, ir.OpResult)): + continue + if self._is_tensor_descriptor(arg): + continue + if inspect.isclass(spec_ty): + class_name = str(arg).replace("class", "") + class_name = class_name.replace(" ", "") + function_name = f"{function_name}_{class_name}" + elif isinstance(arg, (list, tuple)): + function_name = f"{function_name}_{'_'.join(map(str, arg))}" + else: + function_name = f"{function_name}_{arg}" + # we would need a dedicated MR to follow up + unwanted_chars = r"'-![]#,.<>()\":{}=%?@;" + translation_table = str.maketrans("", "", unwanted_chars) + function_name = function_name.translate(translation_table) + # identify address and drop + function_name = re.sub(r"0x[a-f0-9]{8,16}", "", function_name) + function_name = re.sub(r"\s+", " ", function_name) + function_name = function_name.replace(" ", "_") + function_name = function_name.replace("\n", "_") + # max fname is 256 character, leave space + function_name = function_name[:180] + log().info(f"Final mangled function name: 
{function_name}") + return function_name + + def _generate_execution_arguments_for_known_types( + self, arg, arg_spec, arg_name, i, fop_args, iv_block_args + ): + """ + Generate MLIR arguments for known types. + + Sub-DSLs can override this method to handle types that are not + natively supported by the Base DSL. + """ + ir_arg = [] + if is_argument_constexpr(arg, arg_spec, arg_name, i, func): + ir_arg.append(arg) + + return ir_arg, iv_block_args + + def generate_execution_arguments( + self, + args, + kwargs, + fop, + args_spec: inspect.FullArgSpec, + ): + """Create list of arguments that will be passed to MLIR's func.func op""" + + def gen_exec_args(input_args, arg_names, annotations, fop_args): + assert len(input_args) == len(arg_names) + + ir_args = [] + iv_block_args = 0 + for i, arg in enumerate(input_args): + arg_name = arg_names[i] + arg_spec = annotations.get(arg_name, None) + log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, arg_spec) + + # Implicit cast to NumericMeta + if isinstance(arg_spec, t.NumericMeta) and not isinstance( + arg, arg_spec + ): + arg = t.cast(arg, arg_spec) + + ir_arg, iv_block_args = ( + self._generate_execution_arguments_for_known_types( + arg, arg_spec, arg_name, i, fop_args, iv_block_args + ) + ) + + if not ir_arg: + # If it's not a known type, try JIT argument adapter + # to convert the argument if possible + adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg)) + arg = adapter(arg) if adapter else arg + + n_args = len(get_mlir_types(arg)) + blk_args = fop_args[iv_block_args : iv_block_args + n_args] + ir_arg.append(new_from_mlir_values(arg, blk_args)) + iv_block_args += n_args + + self.log_additions(ir_arg) + ir_args.extend(ir_arg) + + return ir_args, iv_block_args + + fop_args = list(fop.regions[0].blocks[0].arguments) + ir_args, iv_block_args = gen_exec_args( + args, args_spec.args, args_spec.annotations, fop_args + ) + ir_kwargs, _ = gen_exec_args( + [kwargs[arg] for arg in args_spec.kwonlyargs], + 
args_spec.kwonlyargs, + args_spec.annotations, + fop_args[iv_block_args:], + ) + ir_kwargs = {k: v for k, v in zip(args_spec.kwonlyargs, ir_kwargs)} + + log().debug("execution args: %s", ", ".join(map(str, ir_args))) + log().debug("execution kwargs: %s", ", ".join(map(str, ir_kwargs))) + return ir_args, ir_kwargs + + @abstractmethod + def _generate_mlir_type_for_tensor_descriptor(self, tensor): + """ + Generate MLIR type for the tensor descriptor. + """ + pass + + @abstractmethod + def _generate_executable_arg_for_tensor_descriptor( + self, mlir_value=None, ptr_tensor_ty=None, tensor=None + ): + """ + Generates executable value for the given tensor descriptor. + """ + pass + + def _get_globals(self): + """ + Combines global and local variables from the current context and the + caller's frame comes. This includes the current module's globals, the + global variables from the caller's frame, and the local variables from + the caller's frame. + + "self.frame" is used to fetch the caller's frame. + + AST preprocessor generates a new python code, so the resulting globals + dictionary is used to execute the python code. + """ + all_globals = {} + if self.frame: + all_globals.update(self.frame.f_globals) + all_globals.update(self.frame.f_locals) + return all_globals + + @abstractmethod + def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool: + pass + + @abstractmethod + def _handle_tensor_descriptor( + self, maybe_tensor, arg_name: str, need_gpu_memory: bool + ) -> Any: + pass + + def _validate_arg(self, arg, arg_index, arg_name, arg_spec): + """ + Validates if the arg is really of the annotated type for type safety. + + The default implementation is empty. Subclasses can override this method to add more validation logic. + Returns None if validation passes, otherwise returns an error derived from DSLBaseError. 
+ """ + pass + + def _generate_jit_func_args_for_known_types( + self, + func, + arg, + arg_name, + arg_spec, + arg_index, + *, + is_host=True, + ): + """ + Generate JIT function arguments for known types. + + Sub-DSLs can override this method to handle types that are not + natively supported by the Base DSL. + """ + + jit_arg_type, jit_arg_attr, jit_exec_arg = [], [], [] + default_attr = ir.DictAttr.get({}) + + if is_argument_constexpr(arg, arg_spec, arg_name, arg_index, func): + jit_exec_arg = jit_arg_type = jit_arg_attr = None + + return jit_exec_arg, jit_arg_type, jit_arg_attr + + def _generate_jit_func_args( + self, + func, + function_name, + args, + kwargs, + args_spec: inspect.FullArgSpec, + *, + is_host=True, + ): + """Generate JIT function arguments.""" + + assert len(args) == len(args_spec.args) and len(kwargs) == len( + args_spec.kwonlyargs + ), ( + f"Input args {len(args)=} and kwargs {len(kwargs)=} must match arg_spec.args " + f"{len(args_spec.args)=} and arg_spec.kwonlyargs {len(args_spec.kwonlyargs)=}" + ) + + jit_arg_types, jit_arg_attrs, jit_exec_args = [], [], [] + jit_adapted_args = [] + default_attr = ir.DictAttr.get({}) + + input_args = [*args, *kwargs.values()] + input_arg_names = [*args_spec.args, *args_spec.kwonlyargs] + for i, (arg_name, arg) in enumerate(zip(input_arg_names, input_args)): + spec_ty = args_spec.annotations.get(arg_name, None) + log().debug("Processing [%d] Argument [%s : %s]", i, arg_name, spec_ty) + + # Implicitly convert into Numeric type if possible + if isinstance(spec_ty, t.NumericMeta) and not isinstance(arg, spec_ty): + arg = t.cast(arg, spec_ty) + + # Type safety check + if spec_ty is not None: + err = self._validate_arg(arg, i, arg_name, spec_ty) + if err is not None: + raise err + + jit_exec_arg, jit_arg_type, jit_arg_attr = ( + self._generate_jit_func_args_for_known_types( + func, + arg, + arg_name, + spec_ty, + i, + is_host=is_host, + ) + ) + + if jit_arg_type is not None and len(jit_arg_type) == 0: + # If not 
any known type, try JIT argument adapter + # to convert the argument + adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg)) + if adapter: + arg = adapter(arg) + jit_adapted_args.append(arg) + + if is_host: + jit_exec_arg.extend(get_c_pointers(arg)) + jit_arg_type.extend(get_mlir_types(arg)) + else: + dyn_vals = extract_mlir_values(arg) + jit_exec_arg.extend(dyn_vals) + jit_arg_type.extend([v.type for v in dyn_vals]) + + if not jit_arg_type or not jit_exec_arg: + if (is_host and hasattr(arg, "__c_pointers__")) or ( + not is_host + and hasattr(arg, "__extract_mlir_values__") + and hasattr(arg, "__new_from_mlir_values__") + ): + pass + else: + raise DSLRuntimeError( + f"failed to generate argument #{i+1} ({arg_name}) for JIT function '{function_name}'.", + context={ + f"Argument {arg_name}": "The DSL attempted to convert it into Dynamic Expression (aka MLIR values) but failed.", + f"Call-site argument value": arg, + f"Call-site argument type": type(arg), + }, + suggestion=f"Consider annotating the argument with `{arg_name} : Constexpr` " + "if it's a value known at compile-time. 
" + f"Otherwise, implement the {'`JitArgument`' if is_host else '`DynamicExpression`'} " + f"protocol or register a custom JIT argument adapter for type `{type(arg)}` to " + "enable dynamic value conversion at runtime.", + ) + + jit_arg_attr.extend([default_attr] * len(jit_arg_type)) + + if jit_arg_type is not None: + jit_exec_args.extend(jit_exec_arg) + jit_arg_types.extend(jit_arg_type) + jit_arg_attrs.extend(jit_arg_attr) + + return jit_exec_args, jit_arg_types, jit_arg_attrs, jit_adapted_args + + def generate_mlir_function_types( + self, func, function_name, input_args, kwargs, args_spec: inspect.FullArgSpec + ): + """Convert input arguments to MLIR function signature also convert numpy arrays to memref.""" + + exe_args, types, attrs, adapted_args = self._generate_jit_func_args( + func, function_name, input_args, kwargs, args_spec, is_host=True + ) + + log().debug("Execution Arguments: %s", ", ".join(map(str, exe_args))) + log().debug("Types: %s", ", ".join(map(str, types))) + + assert len(exe_args) == len( + types + ), "expects the same number of arguments and function parameters" + + return exe_args, types, adapted_args + + @dataclass + class LaunchConfig: + cluster: list = None + grid: list = field(default_factory=lambda: [1, 1, 1]) + block: list = field(default_factory=lambda: [1, 1, 1]) + smem: int = None + async_deps: list = field(default_factory=list) + has_cluster: bool = False + min_blocks_per_mp: int = 0 + auto_smem: bool = False + + def __post_init__(self): + if len(self.grid) != 3: + raise DSLRuntimeError(f"Expect 3d grid!") + if len(self.block) != 3: + raise DSLRuntimeError(f"Expect 3d block!") + + if self.smem is None: + self.smem = 0 + self.auto_smem = True + + self.has_cluster = self.cluster is not None + if self.cluster is None: + self.cluster = [None, None, None] + elif len(self.cluster) != 3: + raise DSLRuntimeError(f"Expect 3d cluster!") + + def diagnostic(self): + """Check command line parameters and enables diagnostic""" + # Check command 
line arguments "-diagnostic" + parser = argparse.ArgumentParser(description="Process diagnostic status.") + parser.add_argument( + "-diagnostic", + nargs="?", + const="all", + choices=["all", "fail", "success", "info", "suggestion"], + help="Set diagnostic status (fail, success, info, suggestion).", + ) + + args, _ = parser.parse_known_args() + ctx = ir.Context.current + + def callback(d): + print(f" [{self.name} Diagnostic] : {d.message}") + + ctx.attach_diagnostic_handler(callback) + + # Early return, don't enable diagnostics + if args.diagnostic is None: + return + + # Enable MLIR Flags + ctx.emit_error_diagnostics = True + ir._GlobalDebug.flag = True + if args.diagnostic == "all": + ir._GlobalDebug.set_types("diagnostic") + else: + ir._GlobalDebug.set_types(f"diagnostic-{args.diagnostic}") + + def get_location(self): + """ + Get python location information and generate MLIR location + """ + + if self.frame is None: + log().debug("Frame is None") + return None + + file_loc = ir.Location.file( + self.frame.f_code.co_filename, self.frame.f_lineno, 0 + ) + + loc = ir.Location.name(self.frame.f_code.co_name, childLoc=file_loc) + return loc + + def compile_and_jit(self, module, pipeline, shared_libs, function_name=""): + """ + Compile and JIT an MLIR module. + """ + + try: + self.diagnostic() + + orig_stdout = sys.stdout + orig_stderr = sys.stderr + sys.stderr = redirect_stderr = io.StringIO() + sys.stdout = redirect_stdout = io.StringIO() + + try: + kernel = self.compiler_provider.compile_and_jit( + module, + pipeline, + shared_libs=shared_libs, + cuda_toolkit=self.envar.cuda_toolkit, + arch=self.envar.arch, + ) + + finally: + sys.stdout = orig_stdout + sys.stderr = orig_stderr + ir._GlobalDebug.flag = False + + # Print captured output. 
+ print(redirect_stdout.getvalue(), file=sys.stdout, end="") + print(redirect_stderr.getvalue(), file=sys.stderr, end="") + + return kernel + + except Exception as e: + raise DSLRuntimeError("🧊🧊🧊 ICE 🧊🧊🧊", cause=e) + finally: + pass + + def preprocess_pipeline(self, pipeline, arch) -> str: + + if self.envar.cuda_toolkit is None: + self.print_warning( + "CUDA_TOOLKIT_PATH environment variable is not set. Cannot set toolkitPath." + ) + + options = { + "toolkitPath": self.envar.cuda_toolkit if self.envar.cuda_toolkit else None, + self.pass_sm_arch_name: arch, + } + + opt_str = "" + for k, v in options.items(): + if v: + opt_str += f"{k}={v} " + + if opt_str: + # Automatically append the pipeline options if any is specified through env var + pattern = re.compile(r"{(.+)}") + match = pattern.search(pipeline) + if match: + opt_str = f"{{{match[1]} {opt_str}}}" + pipeline = re.sub(r"{.+}", opt_str, pipeline) + else: + pipeline = pipeline.rstrip(")") + f"{{{opt_str}}})" + log().debug(f"Using pipeline = {pipeline}") + return pipeline + + def get_shared_libs(self) -> list: + shared_libs = [] + support_libs = self.envar.shared_libs + if support_libs is not None: + _libs = support_libs.split(":") + for lib in _libs: + if not os.path.exists(lib): + raise FileNotFoundError( + errno.ENOENT, os.strerror(errno.ENOENT), lib + ) + shared_libs.append(lib) + else: + self.print_warning(f"{self.name}_LIBS environment variable is not set") + + return shared_libs + + @lru_cache(maxsize=1) + def get_version(self): + version_hash = hashlib.sha256() + + return version_hash + + def get_module_hash(self, module, function_name): + s = io.BytesIO() + module.operation.write_bytecode(s) + for attr, value in self.envar.__dict__.items(): + if value is not None: + s.write(str(value).encode()) + # Add compile options to the hash + s.write(self.compile_options.to_str().encode()) + module_hash = self.get_version().copy() + module_hash.update(s.getvalue()) + module_hash = module_hash.hexdigest() + + 
log().debug("Bytecode=[%s]", s.getvalue().hex()) + log().debug("Version=[%s]", self.get_version().hexdigest()) + log().info( + "Function=[%s] Computed module_hash=[%s]", function_name, module_hash + ) + return module_hash + + def build_module(self, module, function_name: str): + """ + Build the MLIR module, verify and return the module + """ + + # Save IR in a file + if self.envar.keepIR: + save_ir(self.name, module, function_name) + + if self.envar.printIR: + print("\n//===--- ------ Generated IR ------ ---====\n") + module.operation.print( + enable_debug_info=self.envar.generate_source_location + ) + print("\n//===--- --- End of Generated IR -- ---====\n") + + # Verify the module + try: + module.operation.verify() + except Exception as e: + raise DSLRuntimeError(f"🧊🧊🧊 ICE IR Verification Failed 🧊🧊🧊", cause=e) + + return module + + def generate_original_ir( + self, + ir, + func, + funcBody, + kwargs, + function_name, + func_types, + gpu_module_attrs, + args, + args_spec, + ): + # This location is set to None for now; otherwise, calls to the same + # function on different lines would produce different line numbers, + # which would break the cache. + loc = None # self.get_location() + + def build_ir_module(): + module = ir.Module.create(loc=loc) + unit_attr = ir.UnitAttr.get() + module.operation.attributes["gpu.container_module"] = unit_attr + + with ir.InsertionPoint(module.body): + # Always generate gpu module. It's canonicalized by the compiler when it's not used. 
+ self._build_gpu_module(gpu_module_attrs) + + fop = func.FuncOp(function_name, (func_types, []), loc=loc) + fop.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + log().debug("Generated Function OP [%s]", fop) + with ir.InsertionPoint(fop.add_entry_block()): + ir_args, ir_kwargs = self.generate_execution_arguments( + args, kwargs, fop, args_spec + ) + # Call user function body + try: + result = funcBody(*ir_args, **ir_kwargs) + func.ReturnOp([]) + except NameError as name_error: + raise DSLRuntimeError( + f"💥💥💥 Error during runtime code generation for function `{funcBody.__name__}` 💥💥💥", + cause=name_error, + suggestion="Using variables defined in dynamic control flow is not supported. Please give an initial value before control flow.", + ) + except DSLRuntimeError as dsl_error: + # Throw it's already a DSL error + raise dsl_error + return module, result + + # Build IR module + profiler = timer(enable=self.envar.jitTimeProfiling) + module, result = profiler(build_ir_module)() + module_hash = self.get_module_hash(module, function_name) + + module = self.build_module(module, function_name) + + return module, module_hash, result + + def compile_and_cache( + self, module, module_hash, function_name, pipeline, args_spec, no_cache + ): + arch = self.envar.arch + pipeline = self.preprocess_pipeline(self._get_pipeline(pipeline), arch) + shared_libs = self.get_shared_libs() + profiler = timer(enable=self.envar.jitTimeProfiling) + if ( + no_cache + or module_hash not in self.jit_cache + or self.jit_cache[module_hash].ir_module is None + ): + log().info( + "JIT cache miss function=[%s] module_hash=[%s]", + function_name, + module_hash, + ) + # Compile and JIT MLIR module + engine = profiler(self.compile_and_jit)( + module, pipeline, shared_libs, function_name=function_name + ) + else: + log().info( + "JIT cache hit IN-FILE function=[%s] module_hash=[%s]", + function_name, + module_hash, + ) + module = self.jit_cache[module_hash].ir_module + engine = 
self.compiler_provider.jit(module, shared_libs=shared_libs) + capi_func = profiler(engine.lookup)(function_name) + jit_executor = JitExecutor( + self, + engine, + capi_func, + module, + args_spec, + function_name, + jit_time_profiling=self.envar.jitTimeProfiling, + ) + jit_executor = jit_executor.update_jit_cuda_modules(self.kernel_symbols) + + if not no_cache: + # module stored in cache is compiled. + self.jit_cache[module_hash] = jit_executor + + return jit_executor + + def post_compilation_cleanup(self): + """Clean up some internal state after one compilation is completed.""" + # clear the kernel symbols after the compilation is done. + self.kernel_symbols = [] + self.launch_inner_count = 0 + # reset num_kernels to 0 for next compilation. + self.num_kernels = 0 + # reset the compile options after the compilation is done. + self.compile_options = CompileOptions() + + def generate_mlir( + self, + funcBody, + kwargs, + function_name, + gpu_module_attrs, + args, + args_spec, + pipeline, + no_cache, + compile_only, + loc=None, + ): + """Generate MLIR module and compile iself.T_provider.""" + with ir.Context(), ir.Location.unknown(): + # Convert input arguments to MLIR arguments + exe_args, func_types, adapted_args = self.generate_mlir_function_types( + funcBody, function_name, args, kwargs, args_spec + ) + + # Generate original ir module and its hash value. 
+ module, module_hash, result = self.generate_original_ir( + ir, + func, + funcBody, + kwargs, + function_name, + func_types, + gpu_module_attrs, + args, + args_spec, + ) + + # dryrun is used to only generate IR + if self.envar.dryrun: + return result + + if ( + no_cache + or module_hash not in self.jit_cache + or self.jit_cache[module_hash].capi_func is None + ): + # no cache or cache miss, do ir generation/compilation/jit engine + jit_executor = self.compile_and_cache( + module, module_hash, function_name, pipeline, args_spec, no_cache + ) + else: + # cache hit + log().info( + "JIT cache hit IN-MEMORY function=[%s] module_hash=[%s]", + function_name, + module_hash, + ) + jit_executor = self.jit_cache[module_hash] + + self.post_compilation_cleanup() + # If compile_only is set, bypass execution return the jit_executor directly + if compile_only: + return jit_executor + # Run the compiled program + jit_executor.run_compiled_program(exe_args) + + return result + + def run_preprocessor(self, funcBody): + if not hasattr(funcBody, "_preprocessed"): + function_name = funcBody.__name__ + self.funcBody = funcBody + log().info("Started preprocessing [%s]", function_name) + exec_globals = self._get_globals() + transformed_ast = self.preprocessor.transform(funcBody, exec_globals) + if self.envar.print_after_preprocessor: + log().info( + f"# Printing unparsed AST after preprocess of func=`{function_name}` id=`{id(funcBody)}`" + ) + DSLPreprocessor.print_ast(transformed_ast) + funcBody._preprocessed = True + return transformed_ast + return None + + def get_function_ptr(self, original_function): + file_name = inspect.getsourcefile(original_function) + code_object = compile( + original_function._transformed_ast, filename=file_name, mode="exec" + ) + return self.preprocessor.exec( + original_function.__name__, + original_function, + code_object, + self._get_globals(), + ) + + def _get_function_bound_args(self, sig, func_name, *args, **kwargs): + """ + Binds provided arguments to a 
function's signature and applies default values. + + E.g. given a function signature `def foo(a, b=2, c=3)`, and at call-site if we do + `foo(a=1, c=4)`, the returned BoundArguments object will have args = `[1]` + and kwargs = `{'b': 2, 'c': 4}` + + An exception will be raised if binding fails. + """ + try: + bound_args = sig.bind_partial(*args, **kwargs) + bound_args.apply_defaults() + except Exception as e: + raise DSLRuntimeError( + f"Failed to bind arguments to function `{func_name}` with signature `{sig}`", + cause=e, + ) + return bound_args + + def _canonicalize_args(self, sig, *args, **kwargs): + """ + Canonicalize the input arguments so that returned args only contain + positional arguments and kwargs only contain keyword arguments. + """ + function_name = self.funcBody.__name__ + bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs) + canonicalized_args = bound_args.args + canonicalized_kwargs = bound_args.kwargs + return canonicalized_args, canonicalized_kwargs + + def _check_arg_count(self, *args, **kwargs): + if not self.funcBody: + raise DSLRuntimeError("Function body is not set.") + + # Pass the actual function object to inspect.signature to get the signature. + sig = inspect.signature(self.funcBody) + + function_name = self.funcBody.__name__ + + bound_args = self._get_function_bound_args(sig, function_name, *args, **kwargs) + + # Check if all non-default arguments are provided + for param in sig.parameters.values(): + if ( + param.default is inspect.Parameter.empty + and param.name not in bound_args.arguments + ): + raise DSLRuntimeError( + f"Missing required argument in `{function_name}`: '{param.name}'" + ) + + return sig + + def _func(self, funcBody, *args, **kwargs): + """Decorator for MLIR functions. + It cuts the boilerplate code, does the following: + 1. Generates `func.func` + 2. Types translation (numpy arrays -> cute.memref, float -> , etc.) + 3. Compiles and JITs the MLIR module + 4. 
Invokes the generated function + 5. Operator overloading (a + b --> arith.addi a, b) + 6. Generates GPU kernel function with GPU module and kernel attributes baked + """ + if ir.Context.current is None: + pass + elif ir.InsertionPoint.current is not None: + return funcBody(*args, **kwargs) + + function_name = funcBody.__name__ + self.funcBody = funcBody + + pipeline = kwargs.pop("pipeline", None) + gpu_module_attrs = kwargs.pop("gpu_module_attrs", {}) + + # Disable cache + no_cache = kwargs.pop("no_cache", False) + + # Always compile(disable cache) and return the result jit_executor + compile_only = kwargs.pop("compile_only", False) + + if not no_cache and compile_only: + no_cache = True + self.print_warning("Cache is disabled as user wants to compile only.") + + # Check the number of arguments + sig = self._check_arg_count(*args, **kwargs) + + args_spec = inspect.getfullargspec(funcBody) + + # Canonicalize the input arguments + canonicalized_args, canonicalized_kwargs = self._canonicalize_args( + sig, *args, **kwargs + ) + + # Simple name mangling + function_name = self.mangle_name(function_name, canonicalized_args, args_spec) + + # Generate MLIR Context and start generating IR + log().debug(f"Generating MLIR for function '{function_name}'") + result = self.generate_mlir( + funcBody, + canonicalized_kwargs, + function_name, + gpu_module_attrs, + canonicalized_args, + args_spec, + pipeline, + no_cache, + compile_only, + ) + + return result + + class _KernelGenHelper(ABC): + def __init__(self): + self.func_op = None + self.func_type = None + + @abstractmethod + def generate_func_op(self, arg_types, arg_attrs, kernel_name, loc=None): + assert arg_types is not None, "Invalid arg_types!" 
+ assert kernel_name is not None, "kernel name is empty" + pass + + @abstractmethod + def generate_func_ret_op(self): + pass + + @abstractmethod + def generate_launch_op(self, *args, **kwargs): + pass + + @abstractmethod + def get_func_body_start(self): + pass + + @abstractmethod + def enter_gpu_module(module): + """Compute the insertion point into the given module.""" + pass + + @lru_cache(maxsize=1) + def _get_default_stream(self): + """Returns the default stream 0""" + from .runtime import cuda as cuda_helpers + + return cuda_helpers.stream_create() + + def _execute_cuda( + self, fname_cubin, kernel_name, grid_size, block_size, smem_size, stream=None + ): + """ + Executes a specified CUDA kernel from a cubin file, handling module loading, + kernel retrieval, stream creation, kernel launch, and synchronization. + """ + from .runtime import cuda as cuda_helpers + + # Step 1. Load CUDA Module + module = cuda_helpers.load_cubin_module(fname_cubin) + # Step 2. Find CUDA function + kernel_ptr = cuda_helpers.get_kernel_function(module, kernel_name) + + sync_execution_default = False + if stream is None: + stream = self._get_default_stream() + sync_execution_default = True + + # Step 4. Launch the kernel + cuda_helpers.launch_kernel( + kernel_ptr, + grid_size, + block_size, + stream, + smem_size=smem_size, + kernel_args=self.exe_args, + ) + + if sync_execution_default: + # Step 5. Optional Sync cuda stream + cuda_helpers.stream_sync(stream) + + def _execute_by_cuda_driver( + self, + kernel_generator, + generate_cubin, + grid_size, + block_size, + smem_size, + stream=None, + ): + """ + This function builds IR and execute the module using cuda driver. + It doesn't use mlir's cuda runtime + """ + ret = None + + # Step 1. 
Build IR + with ir.Context(), ir.Location.unknown(): + loc = self.get_location() + module = ir.Module.create(loc=loc) + unit_attr = ir.UnitAttr.get() + module.operation.attributes["gpu.container_module"] = unit_attr + with ir.InsertionPoint(module.body): + self._build_gpu_module() + ret, kernel_name = kernel_generator() + log().debug( + f"Kernel generator returned: ret={ret}, kernel_name={kernel_name}" + ) + + module = self.build_module(module, kernel_name) + + # dryrun is used to only generate IR + if self.envar.dryrun: + return ret + + # Generate cubin + fname_cubin = generate_cubin(module, kernel_name) + + # Execute a cuda kernel from cubin + self._execute_cuda( + fname_cubin, kernel_name, grid_size, block_size, smem_size, stream + ) + + return ret + + def generate_kernel_operands_and_types( + self, kernel_func, kernel_name, args_spec, args, kwargs + ): + """ + Generate the operands and types for the kernel function + """ + + kernel_operands, kernel_arg_types, kernel_arg_attrs = [], [], [] + + log().debug( + "Processing GPU kernel call in [%s] mode", + ( + f"Only {self.device_jit_decorator_name}" + if self.device_compilation_only + else f"{self.host_jit_decorator_name} + {self.device_jit_decorator_name}" + ), + ) + + if self.device_compilation_only: + return kernel_operands, kernel_arg_types, kernel_arg_attrs + + kernel_operands, kernel_arg_types, kernel_arg_attrs, _ = ( + self._generate_jit_func_args( + kernel_func, kernel_name, args, kwargs, args_spec, is_host=False + ) + ) + + log().debug("Final kernel_operands: %s", ", ".join(map(str, kernel_operands))) + log().debug("Final kernel_arg_types: %s", ", ".join(map(str, kernel_arg_types))) + log().debug("Final kernel_arg_attrs: %s", ", ".join(map(str, kernel_arg_attrs))) + + assert ( + len(kernel_operands) == len(kernel_arg_types) == len(kernel_arg_attrs) + ), "Size of kernel_operands, kernel_arg_types and kernel_arg_attrs must be equal" + + return kernel_operands, kernel_arg_types, kernel_arg_attrs + + def 
kernel_launcher(self, *dargs, **dkwargs): + def decorator(funcBody): + @wraps(funcBody) + def kernel_wrapper(*args, **kwargs): + """ + Base decorator for generating kernel function + + This decorator provides a template for kernel function generation + including kernel function header/body and kernel launch op at call site + + Optional arguments (with default value in <>): + - requiredArgs <[]>: specifies the mandatory arguments that must present in kernel function signature + the args will be validated and collected as a namedtuple + - optionalArgs <[]>: specifies the optional arguments that might present in kernel function signature + the args will be collected (if present) as a namedtuple + - unitAttrNames <[]>: specifies the name(s) of ir.UnitAttr to be set for kernel function op + - valueAttrDict <{}>: specifies the name(s) and value(s) of ir.Attribute to be set for kernel function op + - kernelGenHelper : specifies the mandatory customized kernel generation helper class (derived from _KernelGenHelper) + + Return value: + A namedtuple "KernelReturns" is returned with following fields: + - kernel_func_ret: the return of the kernel function + - launch_op_ret: the return of the launch op + """ + + requiredArgs = dkwargs.get("requiredArgs", []) + optionalArgs = dkwargs.get("optionalArgs", []) + unitAttrNames = dkwargs.get("unitAttrNames", []) + valueAttrDict = dkwargs.get("valueAttrDict", {}) + kernelGenHelper = dkwargs.get("kernelGenHelper", None) + + kernel_name = funcBody.__name__ + args_spec = inspect.getfullargspec(funcBody) + self.funcBody = funcBody + + # Give each kernel a unique name. (The same kernel may be + # called multiple times, resulting in multiple kernel traces.) + # The mangled name of Python function is part of the name to + # improve readability. + kernel_name = f"kernel_{self.mangle_name(kernel_name, args, args_spec)}_{self.num_kernels}" + self.num_kernels += 1 + + # Step 0. 
Preprocess the arguments + def extract_args(argNames, assertIfNone=False) -> list: + extracted = [] + for name in argNames: + value = kwargs.pop(name, None) + if assertIfNone and value is None: + raise DSLRuntimeError( + f"{name} is required for {kernel_name}" + ) + extracted.append(value) + + return extracted + + RequiredArgs = namedtuple("RequiredArgs", requiredArgs) + req_args = ( + RequiredArgs._make(extract_args(requiredArgs, assertIfNone=True)) + if requiredArgs + else None + ) + OptionalArgs = namedtuple("OptionalArgs", optionalArgs) + opt_args = ( + OptionalArgs._make(extract_args(optionalArgs)) + if optionalArgs + else None + ) + assert ( + kernelGenHelper is not None + ), "kernelGenHelper should be explicitly specified!" + + # check arguments + sig = self._check_arg_count(*args, **kwargs) + + # Canonicalize the input arguments + canonicalized_args, canonicalized_kwargs = self._canonicalize_args( + sig, *args, **kwargs + ) + + kernel_operands, kernel_types, kernel_arg_attrs = ( + self.generate_kernel_operands_and_types( + funcBody, + kernel_name, + args_spec, + canonicalized_args, + canonicalized_kwargs, + ) + ) + + with self._enter_gpu_module(): + log().debug("Generating device kernel") + if self.device_compilation_only: + log().debug("Generating cuda-python arguments") + # Convert input arguments to MLIR arguments + self.exe_args, kernel_types, _ = ( + self.generate_mlir_function_types( + funcBody, + kernel_name, + canonicalized_args, + canonicalized_kwargs, + args_spec, + ) + ) + + helper = kernelGenHelper() + loc = self.get_location() + fop = helper.generate_func_op( + kernel_types, kernel_arg_attrs, kernel_name, loc + ) + log().debug(f"Kernel function op: {fop}") + for attr in unitAttrNames: + fop.attributes[attr] = ir.UnitAttr.get() + for key, val in valueAttrDict.items(): + fop.attributes[key] = val + + fop.sym_visibility = ir.StringAttr.get("public") + with ir.InsertionPoint(helper.get_func_body_start()): + ir_args, ir_kwargs = 
self.generate_execution_arguments( + canonicalized_args, canonicalized_kwargs, fop, args_spec + ) + log().debug( + f"IR arguments - args: {ir_args} ; kwargs: {ir_kwargs}" + ) + # Call user function body + kernel_ret = funcBody(*ir_args, **ir_kwargs) + helper.generate_func_ret_op() + + # Step 3. Generate call site `launch_func` + kernel_sym = ir.SymbolRefAttr.get(["kernels", kernel_name]) + launch_ret = helper.generate_launch_op( + kernelSym=kernel_sym, + kernelOperands=kernel_operands, + requiredArgs=req_args, + optionalArgs=opt_args, + ) + + KernelReturns = namedtuple( + "KernelReturns", ["kernel_func_ret", "launch_op_ret"] + ) + result = KernelReturns( + kernel_func_ret=kernel_ret, launch_op_ret=launch_ret + ) + log().debug(f"Kernel result: {result}, kernel name: {kernel_name}") + return result, kernel_name + + return kernel_wrapper + + if len(dargs) == 1 and callable(dargs[0]): + return decorator(dargs[0]) + else: + return decorator diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/env_manager.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/env_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..fa683477f3fb5b18f5459e19bdd468432590b952 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/env_manager.py @@ -0,0 +1,320 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ +""" +This module provides utilities for the environment variables setup. + +It provides an EnvironmentVarManager, which reads environment variables for the DSL +and caches them for efficient access. + +It also provides utilities to automatically setup a subset of environment variables +based on heuristics. +""" + +import os +import sys +import shutil +import glob +from pathlib import Path +from functools import lru_cache +from typing import Any + +from ..base_dsl.runtime.cuda import get_compute_capability_major_minor +from .utils.logger import log + +IS_WINDOWS = sys.platform == "win32" +CLIB_EXT = ".dll" if IS_WINDOWS else ".so" + +# ============================================================================= +# Environment Variable Helpers +# ============================================================================= + + +@lru_cache(maxsize=None) +def get_str_env_var(var_name, default_value=None): + value = os.getenv(var_name) + return value if value is not None else default_value + + +@lru_cache(maxsize=None) +def get_bool_env_var(var_name, default_value=False): + value = get_str_env_var(var_name) + if value is None: + return default_value + return value not in {"False", "0", ""} + + +@lru_cache(maxsize=None) +def get_int_env_var(var_name, default_value=0): + value = get_str_env_var(var_name) + return int(value) if value and value.isdigit() else default_value + + +@lru_cache(maxsize=None) +def has_env_var(var_name): + return os.getenv(var_name) is not None + + +def detect_gpu_arch(prefix): + """ + Attempts to detect the machine's GPU architecture. + + Returns: + A string representing the GPU architecture (e.g. "70" for compute capability 7.0), + or a default value(e.g. "sm_100") if the GPU architecture cannot be determined. 
+ """ + arch = (None, None) + try: + arch = get_compute_capability_major_minor() + except Exception as e: + log().info(f"Failed to get CUDA compute capability: {e}") + + if arch == (None, None): + # default to sm_100 + arch = (10, 0) + + major, minor = arch + suffix = "" + if major >= 9: + suffix = "a" + + return f"sm_{major}{minor}{suffix}" + + +def find_libs_in_ancestors(start, target_libs, lib_folder_guesses): + """ + Search ancestor directories for a candidate library folder containing all required libraries. + + Starting from the given path, this function traverses up through each parent directory. + For every ancestor, it checks candidate subdirectories (specified by lib_folder_guesses) + for files that match the required library extension (CLIB_EXT). Library file names are + canonicalized by removing the "lib" prefix from their stem. If a candidate directory contains + all of the required libraries (as specified in target_libs), the function returns a list of + absolute paths to these library files. + + Parameters: + start (str or Path): The starting directory from which to begin the search. + target_libs (iterable of str): A collection of required library names (without the "lib" prefix). + lib_folder_guesses (iterable of str): Relative paths from an ancestor directory that may contain the libraries. + + Returns: + list[str] or None: A list of resolved paths to the required library files if found; otherwise, None. + """ + # Traverse through all parent directories of the resolved starting path. + for ancestor in Path(start).resolve().parents: + # Iterate over each candidate relative directory path. + for rel_path in lib_folder_guesses: + target_dir = ancestor / rel_path + # Skip if the candidate directory does not exist. + if not target_dir.is_dir(): + continue + + # Initialize a list to hold the resolved paths of matching library files. + libs_cand = [] + # Create a set of the remaining libraries we need to find. 
+ remaining_libs = set(target_libs) + + # Iterate over all items in the candidate directory. + for p in target_dir.iterdir(): + # Consider only files with the expected library extension. + if p.suffix == CLIB_EXT: + # Canonicalize the library name by removing the "lib" prefix. + lib_name = p.stem.removeprefix("lib") + # If this library is required, add its resolved path and mark it as found. + if lib_name in remaining_libs: + libs_cand.append(str(p.resolve())) + remaining_libs.remove(lib_name) + + # If all required libraries have been found, return the list of library paths. + if len(remaining_libs) == 0: + return libs_cand + + # Return None if no candidate directory contains all required libraries. + return None + + +def _find_cuda_home(): + """Find the CUDA installation path using a series of heuristic methods. + Methods below are checked in order, and the function returns on first match: + 1. Checking the environment variables CUDA_HOME and CUDA_PATH. + 2. Searching for the 'nvcc' compiler in the system PATH and deriving the path of cuda. + 3. Scanning common installation directories based on the operating system. + - On Windows systems (when IS_WINDOWS is True), it searches in: + C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.* + - On Unix-like systems, it searches in: + /usr/local/cuda* + + Returns: + Optional[str]: The absolute CUDA installation path if found; otherwise, None. + + Note: + The variable IS_WINDOWS is defined in the module scope. 
+ """ + # Guess #1 + cuda_home = get_str_env_var("CUDA_HOME") or get_str_env_var("CUDA_PATH") + if cuda_home is None: + # Guess #2 + nvcc_path = shutil.which("nvcc") + if nvcc_path is not None: + cuda_home = os.path.dirname(os.path.dirname(nvcc_path)) + else: + # Guess #3 + if IS_WINDOWS: + glob_pat = "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*" + else: + glob_pat = "/usr/local/cuda*" + cuda_homes = glob.glob(glob_pat) + if len(cuda_homes) == 0: + cuda_home = "" + else: + cuda_home = cuda_homes[0] + if not os.path.exists(cuda_home): + cuda_home = None + return cuda_home + + +def get_cuda_toolkit_path(): + """ + Get cuda_toolkit_path. It returns get_str_env_var('CUDA_TOOLKIT_PATH') if + set. Otherwise, attempts to discover a valid CUDA toolkit location and + return. If not found, return None. + """ + # Check if the environment variable is already set, if so, return it immediately. + try: + cuda_toolkit_path_existing = get_str_env_var("CUDA_TOOLKIT_PATH") + if cuda_toolkit_path_existing: + return cuda_toolkit_path_existing + + found_cuda_home = _find_cuda_home() + if found_cuda_home: + return found_cuda_home + except Exception as e: + log().info("default_env: exception on get_cuda_toolkit_path", e) + return None + + +def get_prefix_dsl_libs(prefix: str): + """ + Returns get_str_env_var('{prefix}_LIBS') if set. + Otherwise, attempts to discover libs based on heuristics and return + If not found, return None. + """ + # Check if the environment variable is already set, if so, return it immediately. 
+ try: + prefix_libs_existing = get_str_env_var(f"{prefix}_LIBS") + if prefix_libs_existing: + return prefix_libs_existing + + def get_libs_cand(start): + target_libs = { + "mlir_c_runner_utils", + "mlir_runner_utils", + "mlir_cuda_runtime", + } + lib_folder_guesses = [ + "lib", + ] + + libs_cand = find_libs_in_ancestors(start, target_libs, lib_folder_guesses) + if libs_cand: + dsl_libs = ":".join(libs_cand) + return dsl_libs + + return None + + # find from install folder + dsl_libs = get_libs_cand(__file__) + + if not dsl_libs: + # try to find from build folder structure + dsl_libs = get_libs_cand(Path(__file__).parent.parent.resolve()) + + return dsl_libs + + except Exception as e: + log().info(f"default_env: exception on get_prefix_dsl_libs", e) + return None + + +class EnvironmentVarManager: + """Manages environment variables for configuration options. + + Printing options: + - [DSL_NAME]_LOG_TO_CONSOLE: Print logging to stderr (default: False) + - [DSL_NAME]_PRINT_AFTER_PREPROCESSOR: Print after preprocess (default: False) + - [DSL_NAME]_PRINT_IR: Print generated IR (default: False) + - [DSL_NAME]_FILTER_STACKTRACE: Filter internal stacktrace (default: True) + File options: + - [DSL_NAME]_KEEP_IR: Save generated IR in a file (default: False) + - [DSL_NAME]_LOG_TO_FILE: Store all logging into a file, excluding COMPILE_LOGS (default: False) + Other options: + - [DSL_NAME]_LOG_LEVEL: Logging level to set, for LOG_TO_CONSOLE or LOG_TO_FILE (default: 1). 
+ - [DSL_NAME]_DRYRUN: Generates IR only (default: False) + - [DSL_NAME]_ARCH: GPU architecture (default: "sm_100") + - [DSL_NAME]_WARNINGS_AS_ERRORS: Enable warnings as error (default: False) + - [DSL_NAME]_WARNINGS_IGNORE: Ignore warnings (default: False) + - [DSL_NAME]_ENABLE_OPTIMIZATION_WARNINGS: Enable warnings of optimization warnings (default: False) + - [DSL_NAME]_JIT_TIME_PROFILING: Whether or not to profile the IR generation/compilation/execution time (default: False) + - [DSL_NAME]_DISABLE_FILE_CACHING: Disable file caching (default: False) + - [DSL_NAME]_FILE_CACHING_CAPACITY: Limits the number of the cache save/load files (default: 1000) + - [DSL_NAME]_LIBS: Path to dependent shared libraries (default: None) + - [DSL_NAME]_NO_SOURCE_LOCATION: Generate source location (default: False) + """ + + def __init__(self, prefix="DSL"): + self.prefix = prefix # change if needed + + # Printing options + self.print_after_preprocessor = get_bool_env_var( + f"{prefix}_PRINT_AFTER_PREPROCESSOR", False + ) + self.printIR = get_bool_env_var(f"{prefix}_PRINT_IR", False) + self.filterStacktrace = get_bool_env_var(f"{prefix}_FILTER_STACKTRACE", True) + # File options + self.keepIR = get_bool_env_var(f"{prefix}_KEEP_IR", False) + # Logging options + self.log_to_console = get_bool_env_var(f"{prefix}_LOG_TO_CONSOLE", False) + self.log_to_file = get_bool_env_var(f"{prefix}_LOG_TO_FILE", False) + if ( + has_env_var(f"{prefix}_LOG_LEVEL") + and not self.log_to_console + and not self.log_to_file + ): + log().warning( + f"Log level was set, but neither logging to file ({prefix}_LOG_TO_FILE) nor logging to console ({prefix}_LOG_TO_CONSOLE) is enabled!" 
+ ) + self.log_level = get_int_env_var(f"{prefix}_LOG_LEVEL", 1) + + # Other options + self.dryrun = get_bool_env_var(f"{prefix}_DRYRUN", False) + self.arch = get_str_env_var(f"{prefix}_ARCH", detect_gpu_arch(prefix)) + self.warnings_as_errors = get_bool_env_var( + f"{prefix}_WARNINGS_AS_ERRORS", False + ) + self.warnings_ignore = get_bool_env_var(f"{prefix}_WARNINGS_IGNORE", False) + self.enable_optimization_warnings = get_bool_env_var( + f"{prefix}_ENABLE_OPTIMIZATION_WARNINGS", False + ) + self.jitTimeProfiling = get_bool_env_var(f"{prefix}_JIT_TIME_PROFILING", False) + self.disable_file_caching = get_bool_env_var( + f"{prefix}_DISABLE_FILE_CACHING", False + ) + self.file_caching_capacity = get_int_env_var( + f"{prefix}_FILE_CACHING_CAPACITY", 1000 + ) + self.generate_source_location = not get_bool_env_var( + f"{prefix}_NO_SOURCE_LOCATION", False + ) + # set cuda + self.cuda_toolkit = get_cuda_toolkit_path() + + # set mlir shared libraries + self.shared_libs = get_prefix_dsl_libs(prefix) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/jit_executor.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/jit_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..83268009c85ef64967d6a81ab886ebeb704f140d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/jit_executor.py @@ -0,0 +1,357 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides jit executor related classes +""" +import ctypes +import inspect +import io +from typing import get_origin + +import numpy as np + +# MLIR modules imports +from .._mlir import ir + +# Local modules imports +from . import typing as t +from .common import DSLRuntimeError +from .runtime import cuda as cuda_helpers +from .runtime.jit_arg_adapters import JitArgAdapterRegistry, is_arg_spec_constexpr +from .typing import get_c_pointers +from .utils.logger import log +from .utils.timer import timer + + +class CudaSingleModule: + def __init__(self, cuda_module, kernel_ptr): + self.cuda_module = cuda_module + self.kernel_ptr = kernel_ptr + + +class CudaModules: + def __init__(self, modules, args): + # list of CudaSingleModule + self.modules = modules + # extra kernel ptr arguments for launch + self.args = args + + +class JitExecutor: + def __init__( + self, + dsl, + engine, + capi_func, + ir_module, + args_spec, + function_name, + cuda_modules: CudaModules = None, + jit_time_profiling=False, + ): + self.dsl = dsl + self.engine = engine + self.capi_func = capi_func + self.ir_module = ir_module + self.args_spec = args_spec + self.function_name = function_name + if args_spec is not None: + self.original_args_spec = args_spec + self.args_spec = self.filter_runtime_arg_spec(args_spec) + # cuda kernels + self.cuda_modules = cuda_modules + self.jit_time_profiling = jit_time_profiling + + def filter_runtime_arg_spec(self, arg_spec: inspect.FullArgSpec): + runtime_args = [] + runtime_annotations = {} + runtime_defaults = [] + + # 
Calculate the offset where defaults start in the original args + if arg_spec.defaults: + defaults_start_idx = len(arg_spec.args) - len(arg_spec.defaults) + else: + defaults_start_idx = len(arg_spec.args) + + # Filter arguments and maintain their properties + for i, arg_name in enumerate(arg_spec.args): + arg_type = arg_spec.annotations.get(arg_name, None) + + # Skip compile-time arguments + if is_arg_spec_constexpr(arg_type, arg_name, i, self.function_name): + continue + + # Keep runtime arguments + runtime_args.append(arg_name) + if arg_name in arg_spec.annotations: + runtime_annotations[arg_name] = arg_type + + # Keep corresponding default if it exists + if i >= defaults_start_idx: + default_idx = i - defaults_start_idx + runtime_defaults.append(arg_spec.defaults[default_idx]) + + # Filter kwonlyargs and their defaults + runtime_kwonlyargs = [] + runtime_kwonlydefaults = {} + + if arg_spec.kwonlyargs: + for kwarg in arg_spec.kwonlyargs: + arg_type = arg_spec.annotations.get(kwarg, None) + + # Apply same filtering logic + if is_arg_spec_constexpr(arg_type, kwarg, i, self.function_name): + continue + + runtime_kwonlyargs.append(kwarg) + if kwarg in arg_spec.annotations: + runtime_annotations[kwarg] = arg_type + if arg_spec.kwonlydefaults and kwarg in arg_spec.kwonlydefaults: + runtime_kwonlydefaults[kwarg] = arg_spec.kwonlydefaults[kwarg] + + # Convert runtime_defaults to tuple if not empty (as expected by FullArgSpec) + runtime_defaults = tuple(runtime_defaults) if runtime_defaults else None + + return inspect.FullArgSpec( + args=runtime_args, + varargs=arg_spec.varargs, # Keep original varargs + varkw=arg_spec.varkw, # Keep original varkw + defaults=runtime_defaults, + kwonlyargs=runtime_kwonlyargs, + kwonlydefaults=runtime_kwonlydefaults if runtime_kwonlydefaults else None, + annotations=runtime_annotations, + ) + + def __del__(self): + if self.cuda_modules: + cuda_modules = [module.cuda_module for module in self.cuda_modules.modules] + for module in 
set(cuda_modules): + cuda_helpers.unload_cubin_module(module) + + def get_constexpr_args(self) -> list[dict[str, int | str]]: + """ + This function returns the constexpr args that have been pruned from the original function signature. + The return type is a list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name). + + :return: list of dicts, each dict contains the argument index (argument_index) and argument name (argument_name). + :rtype: list[dict[str, int | str]] + """ + if self.original_args_spec is None: + return list() + constexpr_args = list() + for i, arg_name in enumerate(self.original_args_spec.args): + if arg_name not in self.args_spec.args: + constexpr_args.append({"argument_index": i, "argument_name": arg_name}) + + if self.original_args_spec.kwonlyargs: + for kwarg in self.original_args_spec.kwonlyargs: + if kwarg not in self.args_spec.kwonlyargs: + constexpr_args.append( + {"argument_index": None, "argument_name": kwarg} + ) + return constexpr_args + + def generate_execution_args(self, args, kwargs, args_spec: inspect.FullArgSpec): + """ + This function is the prune version of `generate_mlir_function_types` which only generates execution args + to get rid of mlir context. 
+ """ + + # Process positional arguments with defaults + rectified_args = list(args) + if args_spec.defaults and len(args) < len(args_spec.args): + rectified_args.extend(args_spec.defaults[len(args) - len(args_spec.args) :]) + for k, v in kwargs.items(): + if k in args_spec.args: + idx = args_spec.args.index(k) + if idx < len(rectified_args): + rectified_args[idx] = v + else: + rectified_args.append(v) + + # Process keyword arguments + rectified_kwargs = {k: v for k, v in kwargs.items() if k not in args_spec.args} + if args_spec.kwonlydefaults and len(rectified_kwargs) < len( + args_spec.kwonlyargs + ): + rectified_kwargs.update(args_spec.kwonlydefaults) + + # args/kwargs must match arg_specs + if len(rectified_args) != len(args_spec.args) or len(rectified_kwargs) != len( + args_spec.kwonlyargs + ): + raise DSLRuntimeError( + "input args/kwargs length does not match runtime function signature!", + context={ + "input args length": len(rectified_args), + "input kwargs length": len(rectified_kwargs), + "function signature args length": len(args_spec.args), + "function signature kwonlyargs length": len(args_spec.kwonlyargs), + }, + ) + + exe_args = [] + adapted_args = [] + input_args = rectified_args + list(rectified_kwargs.values()) + input_arg_names = args_spec.args + args_spec.kwonlyargs + for arg, arg_name in zip(input_args, input_arg_names): + # short-cut for args already converted + if hasattr(arg, "__c_pointers__"): + exe_args.extend(arg.__c_pointers__()) + continue + + arg_type = args_spec.annotations.get(arg_name, None) + + # Implicit cast to NumericMeta + if isinstance(arg_type, t.NumericMeta): + arg = t.cast(arg, arg_type) + else: + # If not any known type, try registered adapter to do the conversion + adapter = JitArgAdapterRegistry.get_registered_adapter(type(arg)) + if adapter: + arg = adapter(arg) + adapted_args.append(arg) + + exe_args.extend(get_c_pointers(arg)) + + return exe_args, adapted_args + + def __call__(self, *args, **kwargs): + exe_args, 
adapted_args = self.generate_execution_args( + args, kwargs, self.args_spec + ) + + self.run_compiled_program(exe_args) + + # Assume each execution args has type `c_void_p` to reduce the overhead of `ctypes.cast`. + def get_invoke_packed_args(self, exe_args): + if self.cuda_modules: + exe_args += self.cuda_modules.args + packed_args = (ctypes.c_void_p * len(exe_args))() + for argNum in range(len(exe_args)): + packed_args[argNum] = exe_args[argNum] + return packed_args + + def run_compiled_program(self, exe_args): + if self.jit_time_profiling: + profiler = timer(enable=True) + try: + packed_args = profiler(self.get_invoke_packed_args)(exe_args) + profiler(self.capi_func)(packed_args) + except Exception as e: + raise DSLRuntimeError(f"💥💥💥 Runtime Crash 💥💥💥", cause=e) + else: + try: + packed_args = self.get_invoke_packed_args(exe_args) + self.capi_func(packed_args) + except Exception as e: + raise DSLRuntimeError(f"💥💥💥 Runtime Crash 💥💥💥", cause=e) + + def update_jit_cuda_modules(self, kernel_symbols): + # preload cuda module from compiled cubin in ir and store to jit_executor.kernels. + if len(kernel_symbols) > 0: + extra_args = [] + module = self.ir_module + cuda_kernel_cache = dict() + cuda_driver_version = cuda_helpers.get_driver_version() + for sym in kernel_symbols: + if sym not in cuda_kernel_cache: + log().debug(f"Loading CUDA module for symbol: {sym}") + + # load cuda module/get function pointer from module and cache + def walk_callback(sym, func_sym, cubin_data): + cubin_module = cuda_helpers.load_cubin_module_data(cubin_data) + kernel_ptr = cuda_helpers.get_kernel_function( + cubin_module, func_sym + ) + # Enable non-portable cluster size for CUDA version 11.8 or higher. 
+ if cuda_driver_version >= 11080: + cuda_helpers.set_kernel_attribute( + kernel_ptr, + cuda_helpers.cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, + 1, + ) + cuda_kernel_cache[sym] = CudaSingleModule( + cubin_module, kernel_ptr + ) + + self.walk_module_and_get_cubin_data(module, sym, walk_callback) + else: + log().debug(f"Symbol {sym} already in cache") + # check if kernel is empty. + if sym in cuda_kernel_cache: + extra_args.append( + ctypes.c_void_p(cuda_kernel_cache[sym].kernel_ptr.getPtr()) + ) + # store to the jit result if jit result is cached. + self.cuda_modules = CudaModules(cuda_kernel_cache.values(), extra_args) + + return self + + def _get_escaped_cubin_bytes(self, cubin_data): + """This function escapes cubin data from mlir raw bytecode to executable binary bytes""" + + def ishex(inp): + return ( + inp in range(0x30, 0x3A) + or inp in range(0x61, 0x67) + or inp in range(0x41, 0x47) + ) + + converted = bytearray() + idx = 0 + while idx < len(cubin_data): + # escape the original bytes + if cubin_data[idx] == 0x5C: + # if data of idx is b'\\' + if ishex(cubin_data[idx + 1]) and ishex(cubin_data[idx + 2]): + converted += bytearray.fromhex( + cubin_data[idx + 1 : idx + 3].decode() + ) + idx += 3 + elif cubin_data[idx + 1] == 0x5C: + converted.append(cubin_data[idx]) + idx += 2 + else: + # no escape, directly write + converted.append(cubin_data[idx]) + idx += 1 + return bytes(converted) + + def walk_module_and_get_cubin_data(self, module, sym, callback): + """This function is used to walk gpu binary op, extract the cubin inside, and process cubin data with callback.""" + + def walk_gpu_binary_op(op): + if op.name != "gpu.binary": + return ir.WalkResult.ADVANCE + s = io.BytesIO() + op.write_bytecode(s) + cubin_data = s.getvalue() + if sym.encode() not in cubin_data: + return ir.WalkResult.ADVANCE + + if ( + "kernels" != op.opview.sym_name.value + and sym != op.opview.sym_name.value + ): + return ir.WalkResult.ADVANCE + # 
function symbol of kernel(gpu.launch_func) is equal to sym name in mlir + func_sym = sym + if sym == op.opview.sym_name.value and not sym.endswith("_kernel"): + func_sym = sym.rsplit("_", 1)[0] + + cubin_data = cubin_data.split(b'bin = "')[1].split(b'">')[0] + cubin_data = self._get_escaped_cubin_bytes(cubin_data) + callback(sym, func_sym, cubin_data) + return ir.WalkResult.ADVANCE + + module.operation.walk(walk_gpu_binary_op) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc475fdda59450f07c35ae244d6223446470c6d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/__init__.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides a runtime utility functions that are needed for +the DSL. +""" + +from . import dlpack_types +from . import cuda +from . 
import jit_arg_adapters + +__all__ = [ + "dlpack_types", + "cuda", + "jit_arg_adapters", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/cuda.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..97ae778c0cd5ae19d20fac8e045e2021832f5bbc --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/cuda.py @@ -0,0 +1,476 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ +""" +This module provides CUDA Python helper functions +""" + + +from functools import lru_cache +from dataclasses import dataclass +from typing import List, Optional +import numpy as np +import os +import ctypes + +import cuda.bindings.driver as cuda +import cuda.bindings.nvrtc as nvrtc + +# MLIR imports +from ..._mlir import ir +from ..._mlir.dialects import gpu + +# Local module imports +from ..utils.logger import log as _log +from ..common import * +from .jit_arg_adapters import JitArgAdapterRegistry + + +# ============================================================================= +# Utils +# ============================================================================= + + +def _cudaGetErrorEnum(error): + if isinstance(error, cuda.CUresult): + err, name = cuda.cuGetErrorName(error) + return name if err == cuda.CUresult.CUDA_SUCCESS else "" + elif isinstance(error, nvrtc.nvrtcResult): + return nvrtc.nvrtcGetErrorString(error)[1] + else: + raise DSLRuntimeError("Unknown error type: {}".format(error)) + + +def _get_gpu_arch_info(major, minor): + """Get GPU architecture information and compatibility details.""" + gpu_arch_map = { + (7, 0): ("Volta", "sm_70", ["sm_70"]), # V100 + (7, 5): ("Turing", "sm_75", ["sm_75"]), # RTX 20 Series, Quadro RTX + (8, 0): ("Ampere", "sm_80", ["sm_80"]), # A100 + (8, 6): ("Ampere", "sm_86", ["sm_86", "sm_80"]), # RTX 30 Series + (8, 9): ("Ada", "sm_89", ["sm_89", "sm_86"]), # RTX 40 Series + (8, 7): ("Ampere", "sm_87", ["sm_87", "sm_86", "sm_80"]), # A10, A40 + (9, 0): ("Hopper", "sm_90a", ["sm_90a"]), # H100 + (10, 0): ("Blackwell", "sm_100a", ["sm_100a"]), # B200 + } + return gpu_arch_map.get( + (major, minor), ("Unknown", f"sm_{major}{minor}", [f"sm_{major}{minor}"]) + ) + + +def get_compute_capability_major_minor(device_id: int = 0): + """ + Returns the compute capability of the CUDA device as a tuple of (major, minor). + For example: (8, 0) for Ampere, (9, 0) for Hopper, (10, 0) for Blackwell. + Returns None on failure. 
+ """ + try: + checkCudaErrors(cuda.cuInit(0)) + device = checkCudaErrors(cuda.cuDeviceGet(device_id)) + major = checkCudaErrors( + cuda.cuDeviceGetAttribute( + cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, + device, + ) + ) + minor = checkCudaErrors( + cuda.cuDeviceGetAttribute( + cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, + device, + ) + ) + return major, minor + except RuntimeError as e: + _log().info(f"Failed to get CUDA compute capability: {e}") + return None, None + + +@dataclass +class DeviceInfo: + """Data class to store CUDA device information.""" + + device_count: int = 0 + current_device: int = 0 + device_name: Optional[str] = None + major_version: Optional[int] = None + minor_version: Optional[int] = None + arch_name: Optional[str] = None + sm_arch: Optional[str] = None + compatible_archs: Optional[List[str]] = None + memory_gb: Optional[float] = None + target_arch: Optional[str] = None + error_message: Optional[str] = None + initialization_failed: bool = False + + def pretty_str(self) -> str: + """ + Convert DeviceInfo to a formatted string for display. 
+ """ + info = "" + + if self.initialization_failed: + return f"{Colors.BOLD}- CUDA initialization failed{Colors.RESET}" + + if self.error_message: + return f"{Colors.BOLD}- Failed to get GPU info: {self.error_message}{Colors.RESET}" + + if self.device_count > 0: + info += f"{Colors.BOLD}- CUDA devices available: {self.device_count} (current: {self.current_device})\n" + + if self.major_version is not None and self.minor_version is not None: + info += f"- Architecture: {Colors.BLUE}{self.arch_name}{Colors.RESET} ({Colors.GREEN}{self.sm_arch}{Colors.RESET})\n" + info += f"- Compatible SM archs: {Colors.GREEN}{', '.join(self.compatible_archs or [])}{Colors.RESET}\n" + + if self.memory_gb is not None: + info += f"- Total Memory: {Colors.BLUE}{self.memory_gb:.2f} GB{Colors.RESET}\n" + + else: + info += f"- Compute capability: unknown\n" + info += f"- SM arch: unknown{Colors.RESET}\n" + else: + info += f"- No devices available\n" + + return info + + +def get_device_info() -> DeviceInfo: + """ + Get detailed information about CUDA devices. + Returns a DeviceInfo dataclass with device information. 
+ """ + device_info = DeviceInfo() + + # Initialize CUDA if not already initialized + try: + result = cuda.cuInit(0) + if result[0].value: # Check for error + device_info.initialization_failed = True + return device_info + except: + pass + + try: + # Get device count + result = cuda.cuDeviceGetCount() + device_info.device_count = result[1] if result[0].value == 0 else 0 + + if device_info.device_count > 0: + # Get current device + try: + result = cuda.cuCtxGetDevice() + if result[0].value == 0: + device_info.current_device = result[1] + except: + pass + + # Get device name + try: + name_result = cuda.cuDeviceGetName(100, device_info.current_device) + if name_result[0].value == 0: + device_info.device_name = name_result[1] + except: + pass + + # Get compute capability and architecture info + try: + major, minor = get_compute_capability_major_minor( + device_info.current_device + ) + + # Check if we successfully got the compute capability + if major is not None and minor is not None: + device_info.major_version = major + device_info.minor_version = minor + + arch_name, sm_arch, compatible_archs = _get_gpu_arch_info( + device_info.major_version, device_info.minor_version + ) + + device_info.arch_name = arch_name + device_info.sm_arch = sm_arch + device_info.compatible_archs = compatible_archs + + # Get memory info + try: + total_mem = cuda.cuDeviceGetAttribute( + cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_TOTAL_MEMORY, + device_info.current_device, + ) + if total_mem[0].value == 0: + device_info.memory_gb = total_mem[1] / ( + 1024 * 1024 * 1024 + ) # Convert to GB + except: + pass + + except Exception as e: + pass # Compute capability info will remain None + + except Exception as e: + device_info.error_message = str(e) + + return device_info + + +def checkCudaErrors(result): + """Check CUDA errors and provide detailed error messages.""" + if result[0].value: + error_code = result[0].value + error_name = _cudaGetErrorEnum(result[0]) + + raise 
DSLCudaRuntimeError(error_code, error_name) + + if len(result) == 1: + return None + elif len(result) == 2: + return result[1] + else: + return result[1:] + + +# ============================================================================= +# Driver Helpers +# ============================================================================= + + +@lru_cache(maxsize=1) +def initialize_cuda_context(device_id: int = 0, flags: int = 0): + """ + Initializes the CUDA context for a specified device. + """ + # Initialize CUDA Driver API + _log().info(f"cuInit {flags}") + checkCudaErrors(cuda.cuInit(flags)) + # Retrieve handle for device + _log().info(f"cuDeviceGet {device_id}") + cuDevice = checkCudaErrors(cuda.cuDeviceGet(device_id)) + _log().info(f"{cuDevice} <-- cuDeviceGet") + # Create context + _log().info(f"cuCtxCreate {0} {cuDevice}") + if cuda.CUDA_VERSION >= 13000: + # Use cuCtxCreate_v4 API with explicit CUctxCreateParams None, since v2 + # and v3 API has been removed from CTK 13. + # See https://github.com/NVIDIA/cuda-python/pull/792 + context = checkCudaErrors(cuda.cuCtxCreate(None, 0, cuDevice)) + else: + context = checkCudaErrors(cuda.cuCtxCreate(0, cuDevice)) + _log().info(f"{context} <-- cuCtxCreate") + + return context + + +def load_cubin_module(cubin_file): + """ + Loads a CUBIN file and returns the module. + """ + # Load CUBIN file as binary data + _log().info(f"read cubin {cubin_file}") + with open(cubin_file, "rb") as f: + cubin_data = f.read() + # Load module data + _log().info(f"cuModuleLoadData {np.char.array(cubin_data).ctypes.data}") + module = checkCudaErrors( + cuda.cuModuleLoadData(np.char.array(cubin_data).ctypes.data) + ) + return module + + +def unload_cubin_module(module): + """ + Unloads a CUBIN module. + """ + _log().info(f"cuModuleUnload {module}") + checkCudaErrors(cuda.cuModuleUnload(module)) + + +def load_cubin_module_data(cubin_data): + """ + Loads a CUBIN from data and returns the module. 
+ """ + # Load module data + _log().info(f"cuModuleLoadData {np.char.array(cubin_data).ctypes.data}") + module = checkCudaErrors( + cuda.cuModuleLoadData(np.char.array(cubin_data).ctypes.data) + ) + return module + + +def get_kernel_function(module, kernel_name): + """ + Retrieves the kernel function from the module. + """ + _log().info(f"cuModuleGetFunction {module} {kernel_name}") + kernel = checkCudaErrors( + cuda.cuModuleGetFunction(module, bytes(kernel_name, "utf-8")) + ) + _log().info(f"{kernel} <-- cuModuleGetFunction") + return kernel + + +def launch_kernel(kernel, grid_dims, block_dims, stream, smem_size, kernel_args=None): + """ + Launches the CUDA kernel. + """ + _log().info( + f"cuLaunchKernel {kernel} grid={grid_dims} blocks={block_dims} smem_size={smem_size} stream={stream} {kernel_args}" + ) + checkCudaErrors( + cuda.cuLaunchKernel( + kernel, + grid_dims[0], + grid_dims[1], + grid_dims[2], + block_dims[0], + block_dims[1], + block_dims[2], + smem_size, # Shared memory size + stream, + kernel_args, + 0, # Extra parameters + ) + ) + + +def stream_sync(stream): + """ + Synchronizes the CUDA stream. + """ + _log().info(f"cuStreamSynchronize {stream}") + checkCudaErrors(cuda.cuStreamSynchronize(stream)) + + +def stream_create(id=0): + """ + Creates the CUDA stream. + """ + _log().info(f"cuStreamCreate {id}") + stream = checkCudaErrors(cuda.cuStreamCreate(id)) + _log().info(f"{stream} <-- cuStreamCreate") + return stream + + +def stream_destroy(stream): + """ + Destroys the CUDA stream. + """ + _log().info(f"cuStreamDestroy {stream}") + checkCudaErrors(cuda.cuStreamDestroy(stream)) + + +def context_destroy(context): + """ + Destroys the CUDA context. + """ + _log().info(f"cuCtxDestroy {context}") + checkCudaErrors(cuda.cuCtxDestroy(context)) + + +def allocate(size_in_bytes: int, stream=None): + """ + Allocate device memory based on numpy host array size. 
+ """ + _log().info("Allocate size_in_bytes=[%s] stream=[%s]", size_in_bytes, stream) + if stream is None: + device_memory = checkCudaErrors(cuda.cuMemAlloc(size_in_bytes)) + else: + device_memory = checkCudaErrors(cuda.cuMemAllocAsync(size_in_bytes, stream)) + _log().info("Allocated [%s]", device_memory) + return device_memory + + +def deallocate(device_pointer, stream=None): + """ + Deallocate the specified device memory pointer. + """ + _log().info( + "Deallocate device_pointer=[%s] stream=[%s]", hex(int(device_pointer)), stream + ) + if stream is None: + checkCudaErrors(cuda.cuMemFree(device_pointer)) + else: + checkCudaErrors(cuda.cuMemFreeAsync(device_pointer, stream)) + + +def memcpy_h2d(host_pointer, device_pointer, size_in_bytes, stream=None): + """ + Copy data from host to device memory. + """ + _log().info( + "Copy host-to-device host_pointer[%s] device_ptr=[%s] size_in_bytes=[%s] stream=[%s]", + hex(host_pointer), + hex(int(device_pointer)), + size_in_bytes, + stream, + ) + if stream is None: + checkCudaErrors(cuda.cuMemcpyHtoD(device_pointer, host_pointer, size_in_bytes)) + else: + checkCudaErrors( + cuda.cuMemcpyHtoDAsync(device_pointer, host_pointer, size_in_bytes, stream) + ) + + +def memcpy_d2h(host_pointer, device_pointer, size_in_bytes, stream=None): + """ + Copy data from device to host memory. + """ + _log().info( + "Copy device-host-to device_pointer=[%s] host_pointer[%s] size_in_bytes=[%s] stream=[%s]", + hex(int(device_pointer)), + hex(host_pointer), + size_in_bytes, + stream, + ) + if stream is None: + checkCudaErrors(cuda.cuMemcpyDtoH(host_pointer, device_pointer, size_in_bytes)) + else: + checkCudaErrors( + cuda.cuMemcpyDtoHAsync(host_pointer, device_pointer, size_in_bytes, stream) + ) + + +def default_stream(): + return cuda.CUstream(0) + + +def get_driver_version(): + """ + Returns the CUDA driver version. 
+ """ + return checkCudaErrors(cuda.cuDriverGetVersion()) + + +def set_kernel_attribute(kernel, attribute, value): + """ + Sets a CUDA kernel attribute. + """ + return checkCudaErrors(cuda.cuFuncSetAttribute(kernel, attribute, value)) + + +@JitArgAdapterRegistry.register_jit_arg_adapter(cuda.CUstream) +class StreamAdapter: + """ + Convert a CUDA stream to a stream representation for JIT arg generation. + """ + + def __init__(self, arg): + self._arg = arg + self._c_pointer = self._arg.getPtr() + + def __new_from_mlir_values__(self, values): + assert len(values) == 1 + return values[0] + + def __c_pointers__(self): + return [self._c_pointer] + + def __get_mlir_types__(self): + return [gpu.AsyncTokenType.get()] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/device_tensor.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/device_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..5addb275b12f2b18e109b0592a87f3044d2fe595 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/device_tensor.py @@ -0,0 +1,121 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import copy + +from . 
import cuda as cuda_helpers +from .tensor_descriptor import * +from ..common import * + + +def allocate(tensor: TensorDescriptor, stream=None): + """ + Allocates GPU memory + """ + if tensor._check_is_managed_by_framework(): + raise DSLRuntimeError( + "GPU tensors are managed by the framework and cannot be modified." + ) + if not tensor.device_pointer is None: + raise DSLRuntimeError("Tensor is already allocated on the device.") + + tensor.device_pointer = cuda_helpers.allocate(tensor.size_in_bytes, stream) + + log().info("Allocate done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) + + +def deallocate(tensor: TensorDescriptor, stream=None): + """ + Deallocates GPU memory + """ + if tensor._check_is_managed_by_framework(): + raise DSLRuntimeError( + "GPU tensors are managed by the framework and cannot be modified." + ) + if tensor.device_pointer is None: + raise DSLRuntimeError("Tensor is not allocated on the device.") + + log().info( + "Deallocating done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer + ) + + cuda_helpers.deallocate(tensor.device_pointer, stream) + tensor.device_pointer = None + + +def copy_to_gpu(tensor: TensorDescriptor, do_allocate=True, stream=None): + """ + Copies data from host memory to the GPU memory. + If do_allocate is True, it first calls allocate + """ + log().info("copyin tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) + if do_allocate: + allocate(tensor, stream) + cuda_helpers.memcpy_h2d( + tensor.data_ptr, tensor.device_pointer, tensor.size_in_bytes, stream + ) + log().info("copyin done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) + return tensor + + +def copy_from_gpu(tensor: TensorDescriptor, do_deallocate=True, stream=None): + """ + Copies data from GPU memory back to the host. 
+ If do_deallocate is True, it calls deallocate + """ + log().info("copyout tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) + if tensor._check_is_managed_by_framework(): + raise DSLRuntimeError( + "GPU tensors are managed by the framework and cannot be modified." + ) + if tensor.device_pointer is None: + raise DSLRuntimeError("Tensor is not allocated on the device.") + + cuda_helpers.memcpy_d2h( + tensor.data_ptr, tensor.device_pointer, tensor.size_in_bytes, stream + ) + if do_deallocate: + deallocate(tensor, stream) + log().info("copyout done tensor=[%s] dev_ptr=[%s]", tensor, tensor.device_pointer) + + +def to_gpu(tensor, stream=None) -> TensorDescriptor: + """ + Copies the tensor to the GPU memory from Host memory + """ + if isinstance(tensor, TensorDescriptor): + new_tensor = copy.copy(tensor) + copy_to_gpu(new_tensor, stream=stream) + return new_tensor + + if TensorDescriptor.can_transformed_to_dlpack(tensor): + new_tensor = TensorDescriptor(tensor) + copy_to_gpu(new_tensor, stream=stream) + return new_tensor + + raise DSLRuntimeError("Unsupported type") + + +def from_gpu(tensor, stream=None) -> TensorDescriptor: + """ + Copies the tensor to the GPU memory from Host memory + """ + if isinstance(tensor, TensorDescriptor): + new_tensor = copy.copy(tensor) + copy_from_gpu(new_tensor, stream=stream) + return new_tensor + + if TensorDescriptor.can_transformed_to_dlpack(tensor): + new_tensor = TensorDescriptor(tensor) + copy_from_gpu(new_tensor, stream=stream) + return new_tensor + + raise DSLRuntimeError("Unsupported type") diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/dlpack_types.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/dlpack_types.py new file mode 100644 index 0000000000000000000000000000000000000000..168c2a9953f74b45cadfcbb6562f89d1bb35cd6d --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/dlpack_types.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides helper structs for dlpack. +DLPack is an open standard for in-memory tensor structures, enabling +seamless sharing of tensors across different frameworks. +Learn more at: https://github.com/dmlc/dlpack +""" + +import ctypes +import enum + + +class DLDeviceType(enum.IntEnum): + """Enums for device types based on the DLPack specification.""" + + kDLCPU = 1 + kDLGPU = 2 + kDLCPUPinned = 3 + + +class DLDataTypeCode: + """Enums for data type codes based on the DLPack specification. + + see https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h + """ + + kDLInt = 0 + kDLUInt = 1 + kDLFloat = 2 + kDLOpaqueHandle = 3 + kDLBfloat = 4 + kDLComplex = 5 + kDLBool = 6 + + +class DLDevice(ctypes.Structure): + """Structure representing the device information in DLPack.""" + + _fields_ = [ + ("device_type", ctypes.c_int), # kDLCPU, kDLGPU, etc. 
+ ("device_id", ctypes.c_int), # Device ID (e.g., GPU ID) + ] + + +class DLDataType(ctypes.Structure): + """Structure representing the data type in DLPack.""" + + _fields_ = [ + ("code", ctypes.c_uint8), # Data type code (e.g., kDLFloat) + ("bits", ctypes.c_uint8), # Number of bits per value + ("lanes", ctypes.c_uint16), # Number of lanes + ] + + +class DLTensor(ctypes.Structure): + """Structure representing the DLTensor in DLPack.""" + + _fields_ = [ + ("data", ctypes.c_void_p), # Pointer to tensor data + ("device", DLDevice), # Device info + ("ndim", ctypes.c_int), # Number of dimensions + ("dtype", DLDataType), # Data type + ("shape", ctypes.POINTER(ctypes.c_int64)), # Shape of tensor + ("strides", ctypes.POINTER(ctypes.c_int64)), # Strides of tensor + ("byte_offset", ctypes.c_uint64), # Byte offset to tensor data + ] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/jit_arg_adapters.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/jit_arg_adapters.py new file mode 100644 index 0000000000000000000000000000000000000000..eb998d16d8fb4bcf592f17ce0f23a81d6e11bff6 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/jit_arg_adapters.py @@ -0,0 +1,188 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides runtime utilities for JIT argument conversion in DSL. 
+""" + +from functools import wraps +from typing import get_origin + +# Local modules imports +from ..common import DSLRuntimeError +from ..typing import ( + Constexpr, + Int32, + Float32, + Boolean, +) + + +def is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func): + """ + Check if the argument spec is a constexpr. + """ + + def _is_reserved_python_func_arg(arg_index, arg_name, func): + """ + Check if the argument is a reserved python function argument. + """ + + if arg_index != 0: + return False + + if arg_name == "self": + return True + + is_classmethod = isinstance(func, classmethod) or ( + hasattr(func, "__func__") and isinstance(func.__func__, classmethod) + ) + return arg_name == "cls" and is_classmethod + + return ( + _is_reserved_python_func_arg(arg_index, arg_name, owning_func) + or (isinstance(arg_spec, type) and issubclass(arg_spec, Constexpr)) + or (get_origin(arg_spec) is Constexpr) + ) + + +def is_argument_constexpr(arg, arg_spec, arg_name, arg_index, owning_func): + """ + Check if the argument is a constexpr. + """ + + def _is_type_argument(arg, arg_annotation): + """ + Check if the argument is a type argument like Type[X] + """ + + return isinstance(arg, type) and ( + arg_annotation is None or get_origin(arg_annotation) is type + ) + + return ( + is_arg_spec_constexpr(arg_spec, arg_name, arg_index, owning_func) + or _is_type_argument(arg, arg_spec) + or arg is None + ) + + +class JitArgAdapterRegistry: + """ + A registry to keep track of the JIT argument adapters. + + An adapter is a callable that converts a Python type to a type with following protocols supported: + - JitArgument + - DynamicExpression + The converted type can then be further processed by DSL to generate arguments for JIT functions. 
+ """ + + # A dictionary with key=type and value=callable + jit_arg_adapter_registry = {} + + @classmethod + def register_jit_arg_adapter(cls, *dargs, **dkwargs): + """ + Register a JIT argument adapter callable + + This can be used as a decorator on any callable like: + + @register_jit_arg_adapter(my_py_type) + def my_adapter_for_my_py_type(arg): + ... + + @register_jit_arg_adapter(my_py_type) + class MyAdapterForMyPythonType: + ... + + The adapters are registered per type. If a type is already registerd, an error will be raised. + """ + + def decorator(*dargs, **dkwargs): + darg_python_ty = dargs[0] + + @wraps(darg_python_ty) + def wrapper(*args, **kwargs): + if len(args) != 1 or not callable(args[0]): + raise DSLRuntimeError( + "a callable must be provided for registering JIT argument adapter" + ) + adapter = args[0] + + if darg_python_ty in cls.jit_arg_adapter_registry: + raise DSLRuntimeError( + f"JIT argument adapter for {darg_python_ty} is already registered!", + context={ + "Registered adapter": cls.jit_arg_adapter_registry[ + darg_python_ty + ], + "Adapter to be registered": adapter, + }, + ) + cls.jit_arg_adapter_registry[darg_python_ty] = adapter + return adapter + + return wrapper + + if len(dargs) > 0: + return decorator(*dargs, **dkwargs) + else: + raise DSLRuntimeError( + "a Python type must be provided for registering JIT argument adapter" + ) + + @classmethod + def get_registered_adapter(cls, ty): + """ + Get the registered JIT argument adapter for the given type. 
+ """ + return cls.jit_arg_adapter_registry.get(ty, None) + + +# ============================================================================= +# JIT Argument Adapters +# ============================================================================= + + +@JitArgAdapterRegistry.register_jit_arg_adapter(int) +@JitArgAdapterRegistry.register_jit_arg_adapter(float) +@JitArgAdapterRegistry.register_jit_arg_adapter(bool) +def _convert_python_scalar(arg): + """ + Convert a Python scalar to a DSL type. + """ + conversion_map = { + int: Int32, + float: Float32, + bool: Boolean, + } + return conversion_map.get(type(arg))(arg) + + +@JitArgAdapterRegistry.register_jit_arg_adapter(tuple) +@JitArgAdapterRegistry.register_jit_arg_adapter(list) +def _convert_python_sequence(arg): + """ + Go through each element in the sequence and convert it to a type that can be + further processed by DSL to generate the corresponding JIT argument(s). + """ + adapted_arg = [] + for elem in arg: + adapter = JitArgAdapterRegistry.get_registered_adapter(type(elem)) + if adapter is not None: + converted_elem = adapter(elem) + adapted_arg.append(converted_elem) + else: + # If no registered adapter is found, just return the original element + adapted_arg.append(elem) + + assert len(adapted_arg) == len(arg) + return type(arg)(adapted_arg) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py new file mode 100644 index 0000000000000000000000000000000000000000..1a992ef68293d6f969ab551b6321c3696c961037 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/runtime/tensor_descriptor.py @@ -0,0 +1,201 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# Use of this software is governed by the terms and conditions of the
+# NVIDIA End User License Agreement (EULA), available at:
+# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
+#
+# Any use, reproduction, disclosure, or distribution of this software
+# and related documentation outside the scope permitted by the EULA
+# is strictly prohibited.
+
+# Helpers
+import itertools, operator
+import ctypes
+from . import dlpack_types as _dpack
+from .dlpack_runtime import (
+    dlpack_to_tensor_desc,
+    get_tensor_desc_data_ptr,
+    get_tensor_desc_is_in_device,
+    get_tensor_desc_element_type,
+    get_tensor_desc_shape,
+    get_tensor_desc_stride,
+    get_tensor_desc_element_size_in_bytes,
+    get_tensor_desc_ndim,
+    get_tensor_desc_dtype_code,
+    get_tensor_desc_dtype_bits,
+    get_tensor_desc_device_type,
+    get_tensor_desc_device_id,
+)
+
+from ..utils.logger import log
+from ..common import *
+from ..typing import (
+    Boolean,
+    Float8E5M2,
+    Int64,
+    Int32,
+    Int16,
+    Int8,
+    Uint64,
+    Uint32,
+    Uint16,
+    Uint8,
+    Float64,
+    Float32,
+    Float16,
+    BFloat16,
+)
+
+
+class TensorDescriptor:
+    def __init__(self, tensor):
+        """Initialize with a tensor that supports the DLPack protocol.
+
+        Args:
+            tensor: Any tensor object that implements __dlpack__ and __dlpack_device__
+        """
+
+        self.tensor = tensor
+        self._capsule = dlpack_to_tensor_desc(tensor)
+
+        self.data_ptr = get_tensor_desc_data_ptr(self._capsule)
+        self.device_type = get_tensor_desc_device_type(self._capsule)
+        self.device_type = _dpack.DLDeviceType(self.device_type)
+
+        if self.device_type == _dpack.DLDeviceType.kDLGPU:
+            self.device_pointer = self.data_ptr
+        elif self.device_type == _dpack.DLDeviceType.kDLCPU:
+            self.device_pointer = None
+        else:
+            raise DSLRuntimeError(
+                f"DLPack device type is not supported {self.device_type}"
+            )
+
+        log().info("TensorDescriptor is created = [%s]", self)
+
+    @staticmethod
+    def can_transformed_to_dlpack(dl_tensor):
+        if not hasattr(dl_tensor, "__dlpack__") or not hasattr(
+            dl_tensor, "__dlpack_device__"
+        ):
+            return False
+        return True
+
+    @property
+    def is_in_device(self):
+        """Check if the tensor is stored on a device."""
+        return self.device_pointer is not None
+
+    @property
+    def device_id(self):
+        """Return device id where tensor resides."""
+        if self.is_in_device:
+            return get_tensor_desc_device_id(self._capsule)
+        return -1
+
+    @property
+    def element_type(self):
+        """Return the corresponding Python type based on DLPack dtype metadata."""
+        str_element_type = get_tensor_desc_element_type(self._capsule)
+        dtype_map = {
+            # bool is 8bit from numpy and torch
+            "Bool": Boolean,
+            "Int64": Int64,
+            "Int32": Int32,
+            "Int16": Int16,
+            "Int8": Int8,
+            "UInt64": Uint64,
+            "UInt32": Uint32,
+            "UInt16": Uint16,
+            "UInt8": Uint8,
+            "Float64": Float64,
+            "Float32": Float32,
+            "Float16": Float16,
+            "BFloat16": BFloat16,
+            "Float8E5M2": Float8E5M2,
+        }
+
+        if str_element_type not in dtype_map:
+            raise KeyError(
+                f"Unsupported element type in dlpack: '{str_element_type}'. Supported types are: {list(dtype_map.keys())}"
+            )
+
+        return dtype_map[str_element_type]
+
+    @property
+    def shape(self):
+        """Return the shape of the tensor."""
+        return get_tensor_desc_shape(self._capsule)
+
+    @property
+    def rank(self):
+        """Return the rank of the tensor."""
+        return get_tensor_desc_ndim(self._capsule)
+
+    @property
+    def strides(self):
+        """Return the strides of the tensor."""
+        return get_tensor_desc_stride(self._capsule)
+
+    @property
+    def element_size_in_bytes(self):
+        """Calculate the element size in bytes of the DLPack tensor."""
+        return get_tensor_desc_element_size_in_bytes(self._capsule)
+
+    @property
+    def size_in_bytes(self):
+        """Calculate the total size in bytes of the DLPack tensor."""
+        # Calculate the number of elements using the shape
+        ndim = get_tensor_desc_ndim(self._capsule)
+        shape = get_tensor_desc_shape(self._capsule)
+        num_elements = 1
+        for i in range(ndim):
+            num_elements *= shape[i]
+
+        # Total bytes
+        total_bytes = self.element_size_in_bytes * num_elements
+        return total_bytes
+
+    def __str__(self):
+        """Return a compact string representation of the device_tensor with a tensor prefix."""
+        # Extract shape
+        shape = "x".join(map(str, self.shape))
+
+        # Extract dtype
+        dtype_code = get_tensor_desc_dtype_code(self._capsule)
+        dtype_bits = get_tensor_desc_dtype_bits(self._capsule)
+        dtype = (
+            f"i{dtype_bits}"
+            if dtype_code == _dpack.DLDataTypeCode.kDLInt
+            else f"f{dtype_bits}"
+        )
+
+        # Extract device
+        device_type = "cpu" if not self.is_in_device else "gpu"
+
+        return f"tensor<{shape}x{dtype}>_{device_type}"
+
+    def _check_is_managed_by_framework(self):
+        """
+        Report whether the tensor's DLPack device type is kDLGPU.
+        Returns a bool; this method itself raises nothing.
+        """
+        return self.device_type == _dpack.DLDeviceType.kDLGPU
+
+    @staticmethod
+    def is_compatible(maybe_tensor_descriptor) -> bool:
+        """Check if the object is a TensorDescriptor or can be converted to one."""
+        return isinstance(
+            maybe_tensor_descriptor, TensorDescriptor
+        ) or TensorDescriptor.can_transformed_to_dlpack(maybe_tensor_descriptor)
+
+
+def from_tensor(tensor) -> TensorDescriptor:
+    """Create a TensorDescriptor from a tensor object."""
+    return TensorDescriptor(tensor)
+
+
+def to_tensor(tensor_descriptor: TensorDescriptor):
+    """Return tensor object from tensor descriptor."""
+    return tensor_descriptor.tensor
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/typing.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/typing.py
new file mode 100644
index 0000000000000000000000000000000000000000..b46cff6de8176217f38af05b8604716c34aae009
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/typing.py
@@ -0,0 +1,1962 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# Use of this software is governed by the terms and conditions of the
+# NVIDIA End User License Agreement (EULA), available at:
+# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
+#
+# Any use, reproduction, disclosure, or distribution of this software
+# and related documentation outside the scope permitted by the EULA
+# is strictly prohibited.
+ +import ctypes +import numpy as np +import operator +from typing_extensions import deprecated +from functools import reduce +from typing import ( + Generic, + Protocol, + Union, + Any, + List, + Type, + TypeVar, + overload, + runtime_checkable, + get_origin, +) +from types import FunctionType +from dataclasses import dataclass +from abc import ABC, abstractmethod + +from .common import * +from .ast_helpers import const_expr +from ._mlir_helpers import arith as arith_helper, lru_cache_ir +from ._mlir_helpers.arith import ArithValue + +from .._mlir import ir +from .._mlir.extras import types as T +from .._mlir.dialects import arith, math + +# ============================================================================= +# Dynamic Expression Protocol +# ============================================================================= + + +@runtime_checkable +class DynamicExpression(Protocol): + """Protocol defining the interface for object holding dynamic values in the DSL. + + This protocol enables classes to represent dynamic values in the DSL. Classes implementing + this protocol can be used in JIT-compiled functions and dynamic value generation. + + It is required for custom data types to work correctly with following JIT features: + * as function argument to call another JIT function from JIT function + * as return value from JIT function + * for constructions like if-else, while-loop, etc. + + :param value: The MLIR operation result value to initialize the object with + :type value: ir.Value + + **Required Methods** + + * ``__extract_mlir_values__``: Extract MLIR values from the object + * ``__new_from_mlir_values__``: Create new instance from MLIR values + + **Implementation Example** + + To implement a custom data type that works with the DSL: + + .. 
code-block:: python + + class CustomData(metaclass=DslType): + def __init__(self, int_value): + self.int_value = int_value + + def __extract_mlir_values__(self): + return [self.int_value] + + def __new_from_mlir_values__(self, values): + return CustomData(values[0]) + + **Usage in JIT Functions** + + When used in JIT-compiled functions, the DSL automatically extracts MLIR values: + + .. code-block:: python + + @jit + def caller(): + x = CustomData(1) + return foo(x) + + This generates MLIR like: + + .. code-block:: mlir + + func @caller() -> i32 { + %0 = func.call @foo(%arg0) : (i32) -> i32 + return %0 : i32 + } + """ + + def __extract_mlir_values__(self): + """Extract MLIR values from this object. + + :return: List of MLIR values representing this object's data + :rtype: List[ir.Value] + """ + raise NotImplementedError + + def __new_from_mlir_values__(self, values): + """Create a new instance from MLIR values. + + :param values: List of MLIR values to construct the object from + :type values: List[ir.Value] + :return: New instance of the implementing class + :rtype: Any + """ + raise NotImplementedError + + +@runtime_checkable +class JitArgument(Protocol): + """ + Protocol class defining the interface for JIT function argument generation. + + This protocol enables classes to provide the necessary information for generating + JIT function arguments and allow the DSL JIT executor to call JIT compiled functions. + + **Required Methods** + + * ``__c_pointers__``: Returns ctypes pointers for runtime execution + * ``__get_mlir_types__``: Returns MLIR types for function definition + * ``__new_from_mlir_values__``: Creates new instances from MLIR values + + **Example** + + .. code-block:: python + + class CustomData: + def __init__(self, int_value, ...): + self.int_value = int_value + ... + + def __c_pointers__(self): + return [ctypes.pointer(ctypes.c_int32(self.int_value)), ...] + + def __get_mlir_types__(self): + return [ir.IntegerType.get(32), ...] 
+ + def __new_from_mlir_values__(self, values): + return CustomData(values[0], ...) + + @jit + def foo(x: CustomData): + a = x.int_value + 1 + ... + + # `CustomData` is an argument of `foo` + foo(CustomData(1, ...)) + + When called like ``y = foo(x)``, the following steps occur: + + 1. JIT compiler generates MLIR function definition using ``__get_mlir_types__`` + + .. code-block:: mlir + + func.func @foo(%arg0: i32, ...) { + ... + + return + } + + 2. JIT function can't use values from Python, so it needs to reconstruct the object from + MLIR values, a.k.a `%arg0`, with ``__new_from_mlir_values__`` and pass it to `foo`. + + Following code demonstrates how JIT compiler reconstructs the object and pass to Python. + + .. code-block:: python + + # Implementation of IR tracing + new_x = CustomData(ir.Value(%arg0), ...) + y = foo(new_x) + # `x.int_value` is %arg0 rather than `c1` defined by Python. + + 3. For Python runtime execution, JIT engine invokes compiled function using ``__c_pointers__`` + pointing to the underlying data object passing to JIT compiled function. + + .. code-block:: python + + jit_engine.invoke(compiled_foo, concat([x.__c_pointers__(), ...])) + """ + + def __c_pointers__(self): + """ + Generate a list of ctypes pointers for the current object. + + :return: List of ctypes pointers + :rtype: List[ctypes.c_void_p] + """ + raise NotImplementedError + + def __get_mlir_types__(self): + """ + Generate a list of MLIR types for the current object. + + :return: List of MLIR types + :rtype: List[ir.Type] + """ + raise NotImplementedError + + def __new_from_mlir_values__(self, values): + """ + Create a new object from MLIR values. 
+ + :param values: List of MLIR values + :type values: List[ir.Value] + :return: A new object that represents the given MLIR values + :rtype: Any + """ + raise NotImplementedError + + +def get_c_pointers(obj): + """ + Given the `obj`, recursively go through it to extract all contained C pointers + """ + if hasattr(obj, "__c_pointers__"): + return obj.__c_pointers__() + elif isinstance(obj, (tuple, list)): + return sum((get_c_pointers(x) for x in obj), []) + elif isinstance(obj, set): + raise DSLRuntimeError( + "Sets are not supported in get_c_pointers to ensure order preservation", + context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", + suggestion="Consider using a list or tuple instead", + ) + return [] + + +def get_mlir_types(obj): + """ + Given the `obj`, recursively go through it to extract all contained MLIR types + """ + if hasattr(obj, "__get_mlir_types__"): + return obj.__get_mlir_types__() + elif hasattr(obj, "__extract_mlir_values__"): + return [v.type for v in obj.__extract_mlir_values__()] + elif isinstance(obj, ir.Value): + return [obj.type] + elif isinstance(obj, (tuple, list)): + return sum((get_mlir_types(x) for x in obj), []) + elif isinstance(obj, set): + raise DSLRuntimeError( + "Sets are not supported in get_mlir_types to ensure order preservation", + context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", + suggestion="Consider using a list or tuple instead", + ) + return [] + + +class DslType(type): + """Metaclass for all DSL types in the system. + + This metaclass provides type system infrastructure for DSL types, handling MLIR + type mappings and NumPy type conversions. 
+ + All data types in DSL must provide the following methods: + + :param mlir_type: Corresponding MLIR type for this DSL type + :type mlir_type: Any, optional + :param is_abstract: Whether this type is abstract, defaults to False + :type is_abstract: bool, optional + + **Required Methods** + + * ``__str__`` (classmethod): Return string representation of the type + * ``__c_pointers__`` (optional): Return list of ctypes pointers of data used to invoke JIT function + * ``__get_mlir_types__``: Return list of MLIR types of the MLIR values contained in the instance + * ``__extract_mlir_values__``: Return list of MLIR values contained in the instance + * ``__new_from_mlir_values__``: Return a new instance from list of MLIR values + + **Attributes** + + :ivar _ir: MLIR provider + :vartype _ir: Any + :ivar _T: MLIR Type system provider + :vartype _T: Any + + **Properties** + + :property mlir_type: Returns the corresponding MLIR type for this DSL type + :type mlir_type: Any + + """ + + _is_abstract: bool + + def __new__(cls, name, bases, attrs, is_abstract=False, **kwargs): + new_cls = super().__new__(cls, name, bases, attrs) + + new_cls._is_abstract = is_abstract + + return new_cls + + @property + def is_abstract(cls): + return cls._is_abstract + + +class NumericMeta(DslType): + """Metaclass for numeric types providing width and numpy dtype information. 
+ + :param width: Bit width of the numeric type, defaults to 8 + :type width: int + :param np_dtype: Corresponding NumPy dtype + :type np_dtype: numpy.dtype, optional + :param mlir_type: Corresponding MLIR type + :type mlir_type: Any, optional + :param is_abstract: Whether the type is abstract, defaults to False + :type is_abstract: bool, optional + + :ivar width: Bit width of the numeric type + :type width: int + :ivar _np_dtype: Corresponding NumPy dtype + :type _np_dtype: Union[numpy.dtype, None] + + :property numpy_dtype: Returns the corresponding NumPy dtype + :rtype numpy_dtype: numpy.dtype + """ + + width: int + + # Placeholder type + _mlir_type = Any + _np_dtype: Union[np.dtype, None] + + def __new__( + cls, + name, + bases, + attrs, + width=8, + np_dtype=None, + mlir_type=None, + is_abstract=False, + **kwargs, + ): + def _extract_mlir_values(self): + return [self.ir_value()] + + def _new_from_mlir_values(self, values: list) -> "Numeric": + res_ty = type(self) + return res_ty(values[0]) + + new_attrs = { + "__extract_mlir_values__": _extract_mlir_values, + "__new_from_mlir_values__": _new_from_mlir_values, + } + new_cls = super().__new__( + cls, + name, + bases, + new_attrs | attrs, + is_abstract=is_abstract, + **kwargs, + ) + + if mlir_type is not None: + new_cls._mlir_type = staticmethod(mlir_type) + + new_cls.width = width + new_cls._np_dtype = np_dtype + return new_cls + + @property + def numpy_dtype(cls): + return cls._np_dtype + + @property + def is_integer(cls) -> bool: ... + + @property + def is_float(cls) -> bool: ... + + def is_same_kind(cls, other: Type) -> bool: + return cls.is_integer == other.is_integer or cls.is_float == other.is_float + + @staticmethod + def from_python(value: Any) -> Type["Numeric"]: + """ + Deduce the DSL type from a Python value. 
+ """ + if isinstance(value, int): + return Int32 + elif isinstance(value, float): + return Float32 + elif isinstance(value, bool): + return Boolean + raise DSLRuntimeError( + f"Could not deduce Type[Numeric] from python value: {value} :{type(value)}" + ) + + @property + def mlir_type(cls): + return cls._mlir_type() # type: ignore + + +Value = TypeVar("Value") + + +def cast(obj: Union[bool, int, float, Value], type_: Type["Numeric"]) -> "Numeric": + """Cast an object to the specified numeric type. + + :param obj: Object to be cast + :type obj: Union[bool, int, float, Value] + :param type_: Target numeric type + :type type_: Type[Numeric] + :raises TypeError: If casting to an abstract type or unsupported type conversion + :return: Object cast to the target numeric type + :rtype: Numeric + + Example:: + >>> x = cast(5, Int32) # Cast integer to Int32 + >>> y = cast(3.14, Float32) # Cast float to Float32 + """ + if type_.is_abstract: + if not isinstance(obj, type_): + raise TypeError( + f"can't cast {obj} to {type_}. Pass in concrete type instead, " + "e.g. Int32, Float32, etc." + ) + # If target_type is abstract, and value is instance of target_type, + # then we can return value as is + else: + # Implicit cast based on using annotation type + obj = type_(obj) + return obj + + +# Option 1: use ir.Value as base +# class IntegerMeta(DslType, type(ir.Value)): +class IntegerMeta(NumericMeta): + """Metaclass for integer types providing signedness information. 
+ + :param width: Bit width of the integer type, defaults to 32 + :type width: int + :param signed: Whether the integer type is signed, defaults to True + :type signed: bool + :param mlir_type: Corresponding MLIR type, defaults to None + :type mlir_type: Any, optional + + :ivar signed: Whether the integer type is signed + :vartype signed: bool + :ivar arith: Arithmetic operations interface + :vartype arith: Any + """ + + signed: bool + + def __new__( + cls, + name, + bases, + attrs, + width=32, + signed=True, + mlir_type=None, + is_abstract=False, + ): + if width == 1: + np_dtype = np.bool_ + elif width == 128: + np_dtype = None + elif signed: + np_dtype = getattr(np, f"int{width}") + else: + np_dtype = getattr(np, f"uint{width}") + + def _c_pointers(self): + if width == 1: + c_value = ctypes.c_bool(self.value) + elif signed: + c_value = getattr(ctypes, f"c_int{width}")(self.value) + else: + c_value = getattr(ctypes, f"c_uint{width}")(self.value) + + return [ctypes.cast(ctypes.pointer(c_value), ctypes.c_void_p)] + + new_attrs = { + "__c_pointers__": _c_pointers, + } + new_cls = super().__new__( + cls, name, bases, attrs | new_attrs, width, np_dtype, mlir_type, is_abstract + ) + new_cls.signed = signed + return new_cls + + def __str__(cls): + return f"{cls.__name__}" + + @property + def is_integer(cls) -> bool: + return True + + @property + def is_float(cls) -> bool: + return False + + @property + def zero(cls) -> int: + return 0 + + @property + def min(cls) -> int: + if cls.signed: + return -(2 ** (cls.width - 1)) + else: + return 0 + + @property + def max(cls) -> int: + if cls.signed: + return 2 ** (cls.width - 1) - 1 + else: + return 2**cls.width - 1 + + def recast_width(cls, width): + type_map = { + 8: Int8, + 16: Int16, + 32: Int32, + 64: Int64, + 128: Int128, + } + if width not in type_map: + raise TypeError(f"Unsupported width: {width}") + return type_map[width] + + +class FloatMeta(NumericMeta): + """Metaclass for floating-point types. 
+ + This metaclass provides type system infrastructure for floating-point types in the DSL, + handling MLIR type mappings and NumPy type conversions. + + :param width: Bit width of the float type, defaults to 32 + :type width: int + :param mlir_type: Corresponding MLIR type, defaults to None + :type mlir_type: Any, optional + :param is_abstract: Whether this is an abstract base class, defaults to False + :type is_abstract: bool, optional + + :ivar _arith: Arithmetic operations interface + :vartype _arith: Any + """ + + _exponent_width: int + _mantissa_width: int + + def __new__(cls, name, bases, attrs, width=32, mlir_type=None, is_abstract=False): + np_dtype = getattr(np, name.lower(), None) + new_cls = super().__new__( + cls, name, bases, attrs, width, np_dtype, mlir_type, is_abstract + ) + # Extract exponent and mantissa bits from class name if it follows Float pattern + # For example: Float8E4M3 -> exponent_width=4, mantissa_width=3 + import re + + if not is_abstract: + match = re.match(r"Float(\d+)E(\d+)M(\d+)(?:.*)", name) + if match: + exp_bits = int(match.group(2)) + mant_bits = int(match.group(3)) + + # Store extracted values as class attributes + new_cls._exponent_width = exp_bits + new_cls._mantissa_width = mant_bits + # Don't have 1-to-1 mapping of narrow precision types like bfloat16, tfloat32, etc. 
+ return new_cls + + def __str__(cls): + return f"{cls.__name__}" + + @property + def is_integer(cls) -> bool: + return False + + @property + def is_float(cls) -> bool: + return True + + @property + def zero(cls) -> float: + return 0.0 + + @property + def inf(cls) -> float: + return float("inf") + + @property + def nan(cls) -> float: + return float("nan") + + @property + def exponent_width(cls) -> int: + return cls._exponent_width + + @property + def mantissa_width(cls) -> int: + return cls._mantissa_width + + def recast_width(cls, width): + type_map = { + 16: Float16, + 32: Float32, + 64: Float64, + } + if width not in type_map: + raise TypeError(f"Unsupported width: {width}") + return type_map[width] + + +def _arith_signless_to_int(a, target_type): + # is_signed: sign of result type + if target_type.width > a.type.width: + # arith dialect consider `1` in `i1` as `-1`, treat it as unsigned for DSL + if target_type.signed and a.type.width > 1: + return arith.extsi(target_type.mlir_type, a) + else: + return arith.extui(target_type.mlir_type, a) + elif target_type.width < a.type.width: + return arith.trunci(target_type.mlir_type, a) + else: + return a + + +def _binary_op_type_promote(a, b, promote_bool: bool = False): + """Promote two numeric operands following type promotion rules. + + :param a: First numeric operand + :type a: Numeric + :param b: Second numeric operand + :type b: Numeric + :param promote_bool: Whether to promote boolean types to Int32 for arithmetic operations, defaults to False + :type promote_bool: bool, optional + :raises ValueError: If implicit float promotion is not supported between the given types + :return: Tuple containing promoted operands and their resulting type + :rtype: tuple[Numeric, Numeric, Type[Numeric]] + + Type promotion rules: + 1. If operands are same type and not bools needing promotion: + - No promotion needed, return original types + 2. If either operand is float: + a. 
If one is float and one is int: + - Convert int to the float type + b. If both are float: + - Promote to higher precision float if width >= 16 + - For same width, promote to more general type (Float32 over TFloat32) + - Otherwise raise ValueError for unsupported promotion + 3. Otherwise, both operands are integers. Integer promotion rules: + a. If promote_bool is True and either operand is bool: + - Promote bool to Int32 for arithmetic operations + + Exceptions for numpy dtype casting: + - array(dtype=np.bool_) + array(dtype=np.bool_) -> array(dtype=np.bool_) + + What is not supported: + - promotion with narrow precision float types which requires explicit cast by user + """ + a_type = a.dtype + b_type = b.dtype + + # Early return for same types (except when they're bools that need promotion) + if a_type == b_type and not (promote_bool and a_type is Boolean): + return a, b, a_type + + # Handle floating point promotions + if a_type.is_float or b_type.is_float: + # Get highest precision float type based on bitwidth + a_width = getattr(a_type, "width", 0) + b_width = getattr(b_type, "width", 0) + + # If one type is integer, convert it to the float type + if a_type.is_float and not b_type.is_float: + b_type = a_type.recast_width(max(a_width, b_width)) + elif b_type.is_float and not a_type.is_float: + a_type = b_type.recast_width(max(a_width, b_width)) + + # Both are float types - handle precision promotion + if a_width > b_width and a_width >= 16: + res_type = a_type + elif b_width > a_width and b_width >= 16: + res_type = b_type + elif a_width == b_width: + # Same bitwidth - handle special cases like TFloat32 -> Float32 and BFloat16 -> Float16 + if a_type is Float64 or b_type is Float64: + res_type = Float64 + elif a_type is Float32 or b_type is Float32: + res_type = Float32 + elif a_type is Float16 or b_type is Float16: + res_type = Float16 + else: + raise ValueError( + f"implicit float promotion of {a_type} or {b_type} is not supported, cast explicitly" + ) + else: 
+ raise ValueError( + f"implicit float promotion of {a_type} or {b_type} is not supported, cast explicitly" + ) + + # Only convert if type is different + new_a = a.to(res_type) if a.dtype != res_type else a + new_b = b.to(res_type) if b.dtype != res_type else b + return new_a, new_b, res_type + + # Handle bool promotion for arithmetic operations + if promote_bool: + if a_type is Boolean and b_type is Boolean: + # Only promote to Int32 when both are bool + a = a.to(Int32) + b = b.to(Int32) + a_type = b_type = a.dtype + + # If both were bools, they're now same type (Int32) + if a_type == b_type: + return a, b, a_type + + # Same type, no promotion needed + if a_type == b_type: + return a, b, a_type + + a_signed = a_type.signed + b_signed = b_type.signed + a_width = a_type.width + b_width = b_type.width + + # Mixed signedness case + if a_signed != b_signed: + unsigned_type = a_type if not a_signed else b_type + signed_type = a_type if a_signed else b_type + unsigned_width = a_width if not a_signed else b_width + + if unsigned_width >= signed_type.width: + # Promote both to unsigned of larger width + res_type = unsigned_type + else: + # Promote both to signed of larger width + res_type = signed_type + + new_a = a.to(res_type) if a.dtype != res_type else a + new_b = b.to(res_type) if b.dtype != res_type else b + return new_a, new_b, res_type + + # Same signedness, different width - promote to larger width + if a_width >= b_width: + return a, b.to(a.dtype), a.dtype + else: + return a.to(b.dtype), b, b.dtype + + +def _binary_op(op, promote_operand=True, promote_bool=False, flip=False): + """Wrapper for binary operations on Numeric types. + + This wrapper handles type promotion, operation execution, and result type determination + for binary operations between Numeric types. 
+ + :param op: The binary operation to perform (e.g., operator.add, operator.sub) + :type op: callable + :param emitter: Function that emits the MLIR operation for dynamic values + :type emitter: callable + :param promote_operand: Whether to promote operands to the same type, defaults to True + :type promote_operand: bool, optional + :param promote_bool: Whether to promote boolean results to Boolean type, defaults to False + :type promote_bool: bool, optional + :param flip: Whether to flip the operands when calling the operation, defaults to False + :type flip: bool, optional + + :raises TypeError: When an unsupported operation is attempted on specific numeric types + + .. note:: + Not all operations are supported for all numeric types. In particular: + + - Subtraction is not fully supported for Integer types + - Multiplication, floor division, and modulo operations may have limited support + - Division (truediv) with integer types is not fully supported and converts to Float32 + """ + + def wrapper(lhs, rhs, *, loc=None, ip=None): + orig_lhs_type = type(lhs) + orig_rhs_type = type(rhs) + + # When called directly with self and other + ty = type(lhs) + # Canonicalize to Numeric type for promotion + if not isinstance(rhs, Numeric): + if not isinstance(rhs, (ArithValue, int, float, bool)): + # This allows rhs class to implement __rmul__ + return NotImplemented + + if isinstance(rhs, ArithValue): + if isinstance(rhs.type, ir.VectorType): + return NotImplemented + + rhs = as_numeric(rhs) + + # default result type to left-hand-side + res_type = ty + + if promote_operand: + lhs, rhs, res_type = _binary_op_type_promote(lhs, rhs, promote_bool) + else: + rhs = ty(rhs) + + if op in ( + operator.lt, + operator.le, + operator.gt, + operator.ge, + operator.eq, + operator.ne, + ): + res_type = Boolean + elif op == operator.truediv and isinstance(lhs, Integer): + res_type = Float32 + elif promote_bool and orig_lhs_type == Boolean and orig_rhs_type == Boolean: + res_type = Boolean 
+ + if isinstance(lhs.value, ArithValue) and isinstance(lhs, Integer): + lhs_val = lhs.value.with_signedness(lhs.signed) + else: + lhs_val = lhs.value + + if isinstance(rhs.value, ArithValue) and isinstance(rhs, Integer): + rhs_val = rhs.value.with_signedness(rhs.signed) + else: + rhs_val = rhs.value + + if flip: + lhs_val, rhs_val = rhs_val, lhs_val + + # Check if the operation is supported by the operands + res_val = op(lhs_val, rhs_val) + return res_type(res_val, loc=loc, ip=ip) + + return wrapper + + +class Numeric(metaclass=NumericMeta, is_abstract=True): + """Base class for all numeric types in the DSL. + + This class provides the foundation for both Integer and Float types, + implementing basic arithmetic operations. + + :param value: The value to store in the numeric type + :type value: Union[bool, int, float, Value] + + :ivar value: The stored numeric value + :vartype value: Union[bool, int, float, Value] + """ + + def __init__(self, value: Union[bool, int, float, Value], *, loc=None, ip=None): + self.value = value + + def __str__(self) -> str: + # Use member's pretty-str method if member object has method. + # This can be extended in future to have better support for IDE, jupyter notebook, etc. + pretty_str = getattr(self.value, "pretty_str", None) + if pretty_str is not None: + return pretty_str() + else: + return "?" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({repr(self.value)})" + + def __hash__(self): + return hash(type(self).__class__) ^ hash(self.value) + + @property + def dtype(self) -> Type["Numeric"]: + return type(self) + + @overload + def to(self, dtype: Type["Numeric"], *, loc=None, ip=None) -> "Numeric": ... + + @overload + def to(self, dtype: Type[int], *, loc=None, ip=None) -> int: ... + + @overload + def to(self, dtype: Type[float], *, loc=None, ip=None) -> float: ... + + @overload + def to(self, dtype: Type[bool], *, loc=None, ip=None) -> bool: ... 
+ + @overload + def to(self, dtype: Type[ir.Value], *, loc=None, ip=None) -> ir.Value: ... + + def to(self, dtype: Type, *, loc=None, ip=None): + """Convert this numeric value to another numeric type. + + If the target type is the same as the current type, returns self. + Otherwise, creates a new instance of the target type with the same value. + + :param dtype: The target numeric type to convert to + :type dtype: Union[Type["Numeric"], Type[int], Type[float], Type[bool]] + :return: A new instance of the target type, or self if types match + :rtype: Numeric + :raises TypeError: If trying to convert an MLIR value to a static Python type + :raises TypeError: If trying to convert to unsupported float types like Float8E4M3, + Float8E4M3B11FNUZ, Float4E2M1FN, Float6E3M2FN, or Float6E2M3FN + + .. note:: + + Unsupported destination float types: + - Float8E4M3 + - Float8E4M3B11FNUZ + - Float4E2M1FN + - Float6E3M2FN + - Float6E2M3FN + + Example:: + + .. code-block:: python + + # Convert between DSL numeric types + x = Int32(5) + y = x.to(Float32) # Converts to Float32(5.0) + + # Convert to Python primitive types + # They are considered as static values at JIT time + z = x.to(int) # Returns Python int 5 + w = y.to(float) # Returns Python float 5.0 + + # This will raise a ValueError + mlir_val = arith.constant(T.i32(), 42) + num = Int32(mlir_val) + num.to(int) # ValueError: unable to convert MLIR value to static type: + """ + if dtype in _unsupported_dst_float_types: + raise TypeError(f"Unsupported destination float type: {dtype}") + + if isinstance(dtype, type(self)): + return self + elif isinstance(dtype, NumericMeta): + return dtype(self) + elif dtype is ir.Value: + if isinstance(self.value, (int, float, bool)): + res = arith_helper.const( + self.value, self.dtype.mlir_type, loc=loc, ip=ip + ) + elif isinstance(self.value, ir.Value): + res = self.value + else: + raise ValueError( + f"cannot convert {type(self)} to {dtype}, " + f"self.value is {self.value.type}" + ) + + if 
not isinstance(res, ArithValue): + raise ValueError(f"Expected ArithValue, got {type(res)} as {res.type}") + + return res.with_signedness(getattr(type(self), "signed", None)) + elif dtype in (int, float, bool): + if isinstance(self.value, ir.Value): + raise ValueError( + f"unable to convert {self.value} to static type: {dtype}" + ) + return dtype(self.value) + else: + raise ValueError(f"unable to convert {type(self)} to {dtype}") + + def ir_value(self, *, loc=None, ip=None) -> ir.Value: + return self.to(ir.Value, loc=loc, ip=ip) + + @property + def zero(self) -> "Numeric": ... + + def __dsl_not__(self, *, loc=None, ip=None): + """DSL implementation of Python's `not` operator. + + Returns True if the value is equal to zero, False otherwise. + This matches Python's behavior where any non-zero number is considered True. + + :param loc: The source location information, defaults to None + :type loc: Optional[Location] + :param ip: The insertion point for the operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: The result of the logical not operation + :rtype: Boolean + """ + if isinstance(self.value, (int, float, bool)): + return not self.value + else: + ty = type(self) + zero_val = arith.constant(ty.mlir_type, ty.zero) + return self.__eq__(ty(zero_val), loc=loc, ip=ip) + + def __dsl_and__(self, other, *, loc=None, ip=None): + """DSL implementation of Python's `and` operator. + + Returns the second operand if the first is truthy, otherwise returns the first operand. + A numeric value is considered truthy if it is non-zero. + + :param other: The right-hand operand + :type other: Numeric + :param loc: The source location information, defaults to None + :type loc: Optional[Location] + :param ip: The insertion point for the operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: The result of the logical and operation + :rtype: Boolean + + Example:: + + 5 and 3 -> 3 + 0 and 3 -> 0 + 3 and 0 and ... 
-> 0 + """ + is_true = self.__dsl_bool__(loc=loc, ip=ip) + + def and_op(lhs, rhs): + if isinstance(lhs, (int, float, bool)): + if isinstance(rhs, (int, float, bool)): + return lhs and rhs + else: + lhs = arith.constant(rhs.type, lhs) + return arith.select(is_true.ir_value(), rhs, lhs, loc=loc, ip=ip) + else: + if isinstance(rhs, (int, float, bool)): + rhs = arith.constant(lhs.type, rhs) + return arith.select(is_true.ir_value(), rhs, lhs, loc=loc, ip=ip) + else: + return arith.select(is_true.ir_value(), rhs, lhs, loc=loc, ip=ip) + + return _binary_op(and_op, promote_bool=True)(self, other, loc=loc, ip=ip) + + def __dsl_or__(self, other, *, loc=None, ip=None): + """DSL implementation of Python's `or` operator. + + Returns the first operand if it is truthy, otherwise returns the second operand. + A numeric value is considered truthy if it is non-zero. + + :param other: The right-hand operand + :type other: Numeric + :param loc: The source location information, defaults to None + :type loc: Optional[Location] + :param ip: The insertion point for the operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: The result of the logical or operation + :rtype: Boolean + + Example:: + + 5 or 3 -> 5 + 0 or 3 -> 3 + 3 or 0 -> 3 + """ + is_true = self.__dsl_bool__(loc=loc, ip=ip) + + def or_op(lhs, rhs): + if isinstance(lhs, (int, float, bool)): + if isinstance(rhs, (int, float, bool)): + return lhs or rhs + else: + lhs = arith.constant(rhs.type, lhs) + return arith.select(is_true.ir_value(), lhs, rhs, loc=loc, ip=ip) + else: + if isinstance(rhs, (int, float, bool)): + rhs = arith.constant(lhs.type, rhs) + return arith.select(is_true.ir_value(), lhs, rhs, loc=loc, ip=ip) + else: + return arith.select(is_true.ir_value(), lhs, rhs, loc=loc, ip=ip) + + return _binary_op(or_op, promote_bool=True)(self, other, loc=loc, ip=ip) + + def __dsl_bool__(self, *, loc=None, ip=None) -> "Boolean": + """DSL implementation of Python's __bool__ method. 
+ + Returns a Boolean indicating whether this value is considered truthy. + For numeric types, returns True if the value is non-zero. + + :param loc: The source location information, defaults to None + :type loc: Optional[Location] + :param ip: The insertion point for the operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: True if this value is truthy (non-zero), False otherwise + :rtype: Boolean + """ + zero = type(self).zero + return self.__ne__(zero, loc=loc, ip=ip) + + def __bool__(self): + if isinstance(self.value, (int, float, bool)): + return bool(self.value) + else: + raise DSLRuntimeError( + f"Unable to convert dynamic `{type(self).__name__}` value to bool at compile time.", + suggestion=[ + "Decorate the parent function with `jit` decorator and with `preprocess` enabled.", + "Ensure not using patterns that DSL does not support.", + "Otherwise, please file a bug report.", + ], + ) + + def __index__(self): + if isinstance(self.value, (int, float, bool)): + return self.value + else: + raise DSLRuntimeError( + f"'{type(self.value)}' object cannot be interpreted as an integer", + suggestion="Mark the loop as dynamic with `dynamic_expr` or `range_dynamic` and decorate the parent function with `jit` decorator", + ) + + def __neg__(self, *, loc=None, ip=None): + if isinstance(self, (bool, int, float)): + return type(self)(-self.value) # type: ignore + else: + return type(self)(-self.value, loc=loc, ip=ip) # type: ignore + + @staticmethod + def _from_python_value(value): + if isinstance(value, Numeric): + return value + + if isinstance(value, bool): + res_type = Boolean + elif isinstance(value, int): + res_type = Int32 + elif isinstance(value, float): + res_type = Float32 + elif isinstance(value, ArithValue): + res_type = Numeric.from_mlir_type(value.type) + else: + raise ValueError( + f"unable to convert {value} in type {type(value)} to Numeric" + ) + return res_type(value) + + def __add__(self, other, *, loc=None, ip=None) -> "Numeric": + 
return _binary_op(operator.add, promote_bool=True)(self, other, loc=loc, ip=ip) + + def __sub__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.sub, promote_bool=True)(self, other, loc=loc, ip=ip) + + def __mul__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.mul, promote_bool=True)(self, other, loc=loc, ip=ip) + + def __floordiv__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.floordiv, promote_bool=True)( + self, other, loc=loc, ip=ip + ) + + def __truediv__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.truediv, promote_bool=True)( + self, other, loc=loc, ip=ip + ) + + def __mod__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.mod, promote_bool=True)(self, other, loc=loc, ip=ip) + + def __radd__(self, other, *, loc=None, ip=None) -> "Numeric": + return self.__add__(other, loc=loc, ip=ip) + + def __rsub__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.sub, promote_bool=True, flip=True)( + self, other, loc=loc, ip=ip + ) + + def __rmul__(self, other, *, loc=None, ip=None) -> "Numeric": + return self.__mul__(other, loc=loc, ip=ip) + + def __rfloordiv__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.floordiv, promote_bool=True, flip=True)( + self, other, loc=loc, ip=ip + ) + + def __rtruediv__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.truediv, promote_bool=True, flip=True)( + self, other, loc=loc, ip=ip + ) + + def __rmod__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.mod, promote_bool=True, flip=True)( + self, other, loc=loc, ip=ip + ) + + def __eq__(self, other, *, loc=None, ip=None) -> "Boolean": + return _binary_op(operator.eq)(self, other, loc=loc, ip=ip) # type: ignore + + def __ne__(self, other, *, loc=None, ip=None) -> "Boolean": + return _binary_op(operator.ne)(self, other, 
loc=loc, ip=ip) # type: ignore + + def __lt__(self, other, *, loc=None, ip=None) -> "Boolean": + return _binary_op(operator.lt)(self, other, loc=loc, ip=ip) # type: ignore + + def __le__(self, other, *, loc=None, ip=None) -> "Boolean": + return _binary_op(operator.le)(self, other, loc=loc, ip=ip) # type: ignore + + def __gt__(self, other, *, loc=None, ip=None) -> "Boolean": + return _binary_op(operator.gt)(self, other, loc=loc, ip=ip) # type: ignore + + def __ge__(self, other, *, loc=None, ip=None) -> "Boolean": + return _binary_op(operator.ge)(self, other, loc=loc, ip=ip) # type: ignore + + def __pow__(self, other, *, loc=None, ip=None) -> "Numeric": + return _binary_op(operator.pow)(self, other, loc=loc, ip=ip) # type: ignore + + def __c_pointers__(self): + raise ValueError( + f"only support built-in types: bool, (u)int{8, 16, 32, 64}, float{32, 64}, but got {type(self)}" + ) + + def __get_mlir_types__(self): + return [type(self).mlir_type] + + @staticmethod + def from_mlir_type(mlir_type): + type_map = { + T.bool(): Boolean, + T.f64(): Float64, + T.f32(): Float32, + T.tf32(): TFloat32, + T.f16(): Float16, + T.bf16(): BFloat16, + T.i(128): Int128, + T.i64(): Int64, + T.i32(): Int32, + T.i16(): Int16, + T.i8(): Int8, + T.si(128): Int128, + T.si64(): Int64, + T.si32(): Int32, + T.si16(): Int16, + T.si8(): Int8, + T.ui(128): Uint128, + T.ui64(): Uint64, + T.ui32(): Uint32, + T.ui16(): Uint16, + T.ui8(): Uint8, + T.f8E5M2(): Float8E5M2, + T.f8E4M3(): Float8E4M3, + T.f8E4M3FN(): Float8E4M3FN, + T.f8E4M3B11FNUZ(): Float8E4M3B11FNUZ, + T.f4E2M1FN(): Float4E2M1FN, + T.f6E2M3FN(): Float6E2M3FN, + T.f6E3M2FN(): Float6E3M2FN, + T.f8E8M0FNU(): Float8E8M0FNU, + } + if mlir_type not in type_map: + raise DSLRuntimeError(f"Unsupported DSL type: {mlir_type}") + return type_map[mlir_type] + + +def as_numeric(obj: Union[bool, int, float, ir.Value, Numeric]) -> Numeric: + """Convert a Python primitive value to a Numeric type. 
+ + :param obj: Python primitive value to convert + :type obj: Union[bool, int, float] + :return: The converted Numeric object + :rtype: Numeric + + Example:: + + .. code-block:: python + + x = as_numeric(5) # Converts to Int32 + y = as_numeric(3.14) # Converts to Float32 + z = as_numeric(True) # Converts to Boolean + """ + if isinstance(obj, Numeric): + return obj + return Numeric._from_python_value(obj) + + +class Integer(Numeric, metaclass=IntegerMeta, mlir_type=T.i32, is_abstract=True): + """A class representing integer values with specific width and signedness. + + This class provides functionality to create and manipulate integer values with + configurable width and signedness. It supports conversion from various input types + including Python scalars, MLIR Values, and other numeric types. + + :param x: The input value to convert to this integer type + :type x: Union[bool, int, float, ir.Value, Integer, Float] + + :return: A new Integer instance with the converted value + :rtype: Integer + + :raises AssertionError: If the type's numpy_dtype is None + :raises NotImplementedError: If converting between different Integer types + :raises ValueError: If the input type is not supported for conversion + :raises OverflowError: If converting float infinity to integer + + Type conversion behavior: + + * Python scalars (bool, int, float): + * Converted through numpy dtype casting + * NaN and infinity values are rejected + * Example: Int8(256) -> -256 (overflow behavior) + + * MLIR Value with IntegerType: + * Width differences handled by signless to signed/unsigned conversion + * Example: i8 -> i8/ui8 depending on target type + + * MLIR Value with FloatType: + * Uses MLIR float-to-int conversion + * NaN and infinity values is undefined behavior + * Example: f32 -> i32/ui32 depending on target type + + * Integer: + * Uses MLIR float-to-int conversion or numpy dtype casting + * Example: Int32(Int32(5)) => 5 + + * Float: + * Uses MLIR float-to-int conversion + * Example: 
Int32(Float(5.7)) -> 5 + + Example usage: + + .. code-block:: python + + x = Int32(5) # From integer + y = Int32(True) # From boolean + z = Int32(3.7) # From float (truncates) + w = Int32(x) # From same Integer type + c5 = arith.constant(5, T.i32()) + a = Int32(c5) # Treat c5 as int32 bitwise + """ + + def __init__(self, x, *, loc=None, ip=None): + ty = type(self) + + if isinstance(x, (bool, int, float)): + # Add check for NaN before numpy conversion + if isinstance(x, float): + if np.isnan(x): + raise ValueError("Cannot convert float NaN to integer") + elif np.isinf(x): + raise OverflowError("Cannot convert float infinity to integer") + + np_dtype = ty.numpy_dtype + assert np_dtype is not None, f"expects numpy.dtype, but got {np_dtype}" + x_val = int(np.array(x).astype(np_dtype)) + elif type(x) == ty: + x_val = x.value + elif isinstance(x, ir.Value): # type: ignore + x_val = x + if isinstance(x.type, ir.IntegerType): # type: ignore + if x.type.width != ty.width: + # signless -> (u)int + x_val = _arith_signless_to_int(x, ty) + elif isinstance(x.type, ir.FloatType): # type: ignore + # float -> (u)int + x_val = arith_helper.fptoi(x, ty.signed, ty.mlir_type, loc=loc, ip=ip) + elif isinstance(x, Integer): + if isinstance(x.value, ir.Value): + x_val = arith_helper.int_to_int(x.ir_value(), ty) + else: + # For non-MLIR values, use numpy casting + src_val = np.array(x.value, dtype=type(x).numpy_dtype) + x_val = int(src_val.astype(ty.numpy_dtype)) + elif isinstance(x, Float): + # float -> int is handled by Integer.__init__ recursively + Integer.__init__(self, x.value) + return + else: + raise DSLRuntimeError(f"{x} to integer conversion is not supported") + + super().__init__(x_val) + + def __invert__(self, *, loc=None, ip=None): + res_type = type(self) + return res_type(self.ir_value(loc=loc, ip=ip).__invert__(loc=loc, ip=ip)) + + def __lshift__(self, other, *, loc=None, ip=None): + return _binary_op(operator.lshift)(self, other, loc=loc, ip=ip) + + def __rlshift__(self, 
other, *, loc=None, ip=None): + other_ = as_numeric(other) + if not isinstance(other_, Integer): + raise ValueError(f"Cannot left shift {other_} with {self}") + return other_.__lshift__(self, loc=loc, ip=ip) + + def __rshift__(self, other, *, loc=None, ip=None): + return _binary_op(operator.rshift)(self, other, loc=loc, ip=ip) + + def __rrshift__(self, other, *, loc=None, ip=None): + other_ = as_numeric(other) + if not isinstance(other_, Integer): + raise ValueError(f"Cannot right shift {other_} with {self}") + return other_.__rshift__(self, loc=loc, ip=ip) + + def __and__(self, other, *, loc=None, ip=None): + return _binary_op(operator.and_)(self, other, loc=loc, ip=ip) + + def __rand__(self, other, *, loc=None, ip=None): + return self.__and__(other, loc=loc, ip=ip) + + def __or__(self, other, *, loc=None, ip=None): + return _binary_op(operator.or_)(self, other, loc=loc, ip=ip) + + def __ror__(self, other, *, loc=None, ip=None): + return self.__or__(other, loc=loc, ip=ip) + + def __xor__(self, other, *, loc=None, ip=None): + return _binary_op(operator.xor)(self, other, loc=loc, ip=ip) + + def __rxor__(self, other, *, loc=None, ip=None): + return self.__xor__(other, loc=loc, ip=ip) + + +class Float(Numeric, metaclass=FloatMeta, mlir_type=T.f32, is_abstract=True): + """A class representing floating-point values. + + :param x: The input value to convert to this float type. + :type x: Union[bool, int, float, ir.Value, Integer, Float] + + Type conversion behavior: + + 1. Python scalars (bool, int, float): + - Converted through numpy dtype casting + - Example: Float32(1.7) -> 1.7 + + 2. MLIR Value with FloatType: + - If width differs: converts between float types + - Example: f16 -> f32 + + 3. MLIR Value with IntegerType: + - Not supported, raises ValueError + + 4. Integer: + - Converts using MLIR int-to-float operation + - Example: Float32(Int32(5)) -> 5.0 + + 5. Float: + - Direct conversion between float types + - Example: Float32(Float32(1.5)) -> 1.5 + + .. 
note:: + The following narrow precision types are only supported in device code: + + 8-bit float types: + - Float8E5M2 + - Float8E4M3 + - Float8E4M3FN + - Float8E8M0FNU + - Float8E4M3B11FNUZ + + 6-bit float types: + - Float6E3M2FN + - Float6E2M3FN + + 4-bit float types: + - Float4E2M1FN + + Narrow precision types and special floating-point formats support matrix on device: + + :raises AssertionError: If the type's numpy_dtype is None + :raises ValueError: If conversion from the input type is not supported + """ + + def __init__(self, x, *, loc=None, ip=None): + ty = type(self) + + if isinstance(x, (bool, int, float)): # type: ignore + # Why we need to convert x to with numpy? + # np_dtype = ty.numpy_dtype + # assert np_dtype is not None, f"expects numpy.dtype, but got {np_dtype}" + # x = float(np.array(x).astype(np_dtype)) + super().__init__(float(x)) + elif isinstance(x, ir.Value): # type: ignore + if isinstance(x.type, ir.IntegerType): # type: ignore + raise DSLRuntimeError("signless to float conversion is not implemented") + elif isinstance(x.type, ir.FloatType): # type: ignore + if x.type != ty.mlir_type: + x = arith_helper.cvtf(x, ty.mlir_type, loc=loc, ip=ip) + super().__init__(x) + elif isinstance(x, Integer): + if isinstance(x.value, ir.Value): # type: ignore + x = arith_helper.itofp( + x.value, type(x).signed, ty.mlir_type, loc=loc, ip=ip + ) + else: + x = float(x.value) + super().__init__(x) + elif isinstance(x, Float): + Float.__init__(self, x.value) + else: + raise DSLRuntimeError(f"{x} to Float conversion is not supported") + + +class Boolean(Integer, metaclass=IntegerMeta, width=1, signed=True, mlir_type=T.bool): + """Boolean type representation in the DSL. + + This class represents boolean values in the DSL, with a width of 1 bit. + It supports conversion from various types to boolean values. 
+ + :param a: Value to convert to Boolean + :type a: Union[bool, int, float, "Value", Numeric] + :param loc: Source location information, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point for MLIR operations, defaults to None + :type ip: Optional[InsertionPoint], optional + :raises DSLRuntimeError: If the input value cannot be converted to Boolean + + Conversion rules: + + 1. Python bool/int/float: + - Converted using Python's bool() function + - Example: Boolean(1) -> True, Boolean(0) -> False + + 2. Numeric: + - Uses the Numeric.value to construct Boolean recursively + + 3. MLIR Value with IntegerType: + - If width is 1: Direct assignment + - Otherwise: Compares with 0 using arith.cmpi + + 4. MLIR Value with FloatType: + - Compares with 0.0 using arith.cmpf + - Uses unordered comparison to handle NaN values + """ + + def __init__( + self, a: Union[bool, int, float, ir.Value, Numeric], *, loc=None, ip=None + ): + value = None + if isinstance(a, (bool, int, float)): + value = bool(a) + elif isinstance(a, Numeric): + Boolean.__init__(self, a.value, loc=loc, ip=ip) + return + elif isinstance(a, ArithValue): + if a.type == T.bool(): + value = a + else: + value = a != arith_helper.const(0, a.type, loc=loc, ip=ip) + if value is None: + raise DSLRuntimeError(f"Cannot convert {a} to Boolean") + super().__init__(value, loc=loc, ip=ip) + self._value_int8 = None + + def ir_value_int8(self, *, loc=None, ip=None): + """ + Returns int8 ir value of Boolean. + When we need to store Boolean tensor element, use ir_value_int8(). 
+ + :param loc: Source location information, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point for MLIR operations, defaults to None + :type ip: Optional[InsertionPoint], optional + :return: The int8 value of this Boolean + :rtype: ir.Value + """ + if self._value_int8 is not None: + return self._value_int8 + self._value_int8 = Int8(self.value, loc=loc, ip=ip).ir_value() + return self._value_int8 + + def __neg__(self, *, loc=None, ip=None): + """Negation operator is not supported for boolean type. + + :param loc: Source location information, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point for MLIR operations, defaults to None + :type ip: Optional[InsertionPoint], optional + :raises TypeError: Always raises this error as negation is not supported + """ + raise TypeError("Negation, the operator `-` is not supported for boolean type") + + +class Int8(Integer, metaclass=IntegerMeta, width=8, signed=True, mlir_type=T.i8): ... + + +class Int16(Integer, metaclass=IntegerMeta, width=16, signed=True, mlir_type=T.i16): ... + + +class Int32(Integer, metaclass=IntegerMeta, width=32, signed=True, mlir_type=T.i32): ... + + +class Int64(Integer, metaclass=IntegerMeta, width=64, signed=True, mlir_type=T.i64): ... + + +class Int128( + Integer, metaclass=IntegerMeta, width=128, signed=True, mlir_type=lambda: T.i(128) +): ... + + +class Uint8(Integer, metaclass=IntegerMeta, width=8, signed=False, mlir_type=T.i8): ... + + +class Uint16( + Integer, metaclass=IntegerMeta, width=16, signed=False, mlir_type=T.i16 +): ... + + +class Uint32( + Integer, metaclass=IntegerMeta, width=32, signed=False, mlir_type=T.i32 +): ... + + +class Uint64( + Integer, metaclass=IntegerMeta, width=64, signed=False, mlir_type=T.i64 +): ... + + +class Uint128( + Integer, metaclass=IntegerMeta, width=128, signed=False, mlir_type=lambda: T.i(128) +): ... 
+ + +class Float64(Float, metaclass=FloatMeta, width=64, mlir_type=T.f64): + def __c_pointers__(self): + if not isinstance(self.value, float): + raise ValueError("only float is supported") + + return [ + ctypes.cast(ctypes.pointer(ctypes.c_double(self.value)), ctypes.c_void_p) + ] + + +class Float32(Float, metaclass=FloatMeta, width=32, mlir_type=T.f32): + @staticmethod + def _get_c_pointer(value: float): + return ctypes.cast(ctypes.pointer(ctypes.c_float(value)), ctypes.c_void_p) + + def __c_pointers__(self): + if not isinstance(self.value, float): + raise ValueError("only float is supported") + + return [Float32._get_c_pointer(self.value)] + + +class TFloat32(Float, metaclass=FloatMeta, width=32, mlir_type=T.tf32): + def __c_pointers__(self): + if not isinstance(self.value, float): + raise ValueError("only float is supported") + return [Float32._get_c_pointer(self.value)] + + +class Float16(Float, metaclass=FloatMeta, width=16, mlir_type=T.f16): + @staticmethod + def _get_c_pointer(value: float): + # Convert float to float16 binary representation + # First convert to numpy float16 to handle the conversion + f16_val = np.float16(value) + # Get the raw bits as a 16-bit integer + bits = f16_val.view(np.uint16) + # Create a short (16-bit int) with those bits + c_val = ctypes.c_short(bits) + return ctypes.cast(ctypes.pointer(c_val), ctypes.c_void_p) + + def __c_pointers__(self): + if not isinstance(self.value, float): + raise ValueError("only float is supported") + return [Float16._get_c_pointer(self.value)] + + +class BFloat16(Float, metaclass=FloatMeta, width=16, mlir_type=T.bf16): + def __c_pointers__(self): + if not isinstance(self.value, float): + raise ValueError("only float is supported") + + return Float.__c_pointers__(self) + + +class Float8E5M2(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E5M2): ... + + +class Float8E4M3FN(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E4M3FN): ... 
+ + +class Float8E4M3B11FNUZ( + Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E4M3B11FNUZ +): ... + + + +# Added missing float types +class Float8E4M3(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E4M3): ... + + +class Float8E8M0FNU(Float, metaclass=FloatMeta, width=8, mlir_type=T.f8E8M0FNU): ... + + +class Float4E2M1FN(Float, metaclass=FloatMeta, width=4, mlir_type=T.f4E2M1FN): ... + + +class Float6E3M2FN(Float, metaclass=FloatMeta, width=6, mlir_type=T.f6E3M2FN): ... + + +class Float6E2M3FN(Float, metaclass=FloatMeta, width=6, mlir_type=T.f6E2M3FN): ... + + +_unsupported_dst_float_types = [ + Float8E4M3, + Float8E4M3B11FNUZ, + Float4E2M1FN, + Float6E3M2FN, + Float6E2M3FN, +] + + +ALL_DTYPES = { + Int8, + Int16, + Int32, + Int64, + Int128, + Uint8, + Uint16, + Uint32, + Uint64, + Uint128, + BFloat16, + Float16, + Float32, + TFloat32, + Float64, + Float8E5M2, + Float8E4M3, + Float8E4M3FN, + Float8E8M0FNU, + Float8E4M3B11FNUZ, + Float4E2M1FN, + Float6E2M3FN, + Float6E3M2FN, +} +__STR_TO_DTYPE__ = {dt.__name__: dt for dt in ALL_DTYPES} + + +def dtype(dtype_) -> Type[Numeric]: + t = None + if const_expr(isinstance(dtype_, str) and dtype_ in __STR_TO_DTYPE__): + t = __STR_TO_DTYPE__[dtype_] + else: + raise TypeError(f"can't interpret {dtype_} as data type") + + return t + + +############################################################## +# Tensor +############################################################## + + +class TensorMeta(DslType): + _element_type = Any + _shape = Any + + """ + Examples: + >>> Tensor[Int32, (3,)] + >>> Tensor[Float32, (3, 4)] + >>> T = TypeVar("T") + >>> Tensor[T, (3, 4, 5)] + """ + + def __new__(cls, name, bases, attrs, element_type=Any, shape=Any): + new_cls = super().__new__(cls, name, bases, attrs) + new_cls._element_type = element_type + new_cls._shape = shape + return new_cls + + +# Generic type +TY = TypeVar("TY") + + +class Constexpr(Generic[TY]): + """Value is passed and computed by python interpreter""" + + pass + + +class 
align: + def __init__(self, value: int): + if value <= 0 or (value & (value - 1)) != 0: + raise DSLRuntimeError("expects align be power of 2 as positive value") + self._value = value + + def __str__(self): + return f"align({self._value})" + + +class PointerMeta(DslType): + def __new__(cls, name, bases, attrs, value_type=Int32, align_=align(1)): + new_cls = super().__new__( + cls, + name, + bases, + attrs, + mlir_type=lambda: getattr(ir, "UnrankedMemRefType").get( + value_type.mlir_type, getattr(ir, "Attribute").parse("0") + ), + ) + new_cls._value_type = value_type + new_cls._align = align_ + return new_cls + + def __eq__(cls, other): + if not isinstance(other, PointerMeta): + return False + return ( + cls._value_type == other._value_type + and cls._align._value == other._align._value + ) # Compare alignment values + + def __hash__(cls): + return hash((cls._value_type, cls._align._value)) # Hash alignment value + + def __getitem__(cls, params) -> Type["Pointer"]: + value_type, align_ = params + + if not isinstance(align_, align): + raise DSLRuntimeError(f"expects align but got {align_}") + + # Create new class with proper name and parameters + new_cls = type( + f"Pointer[{value_type.__name__}, {align_}]", + (Pointer,), + {}, + value_type=value_type, + align_=align_, # Pass alignment to __new__ + ) + return new_cls + + def __str__(cls): + return f"ptr<{cls._value_type}, {cls._align}>" + + +class Pointer(metaclass=PointerMeta): + """ + A pointer to a memory location. + + Examples: + + def foo(a : Pointer[Int32, align=8]): + ... 
+ + """ + + def __init__(self, value): + self.value = value + + def __str__(self): + return f"{self.value} : {type(self)}" + + +class IRConst(Generic[TY]): + """Value is passed as MLIR constant value for (arith.constant).""" + + def __init__(self, ty: TY): + self.ty = ty + + +class IRValue(Generic[TY]): + """Value is passed as MLIR dynamic value.""" + + def __init__(self, ty: TY): + self.ty = ty + + +class IRVariadic: + """ + A helper class to pass a variadic number of arguments to a function. + """ + + def __init__(self, operands): + """ + Create a list of variadic operands. `operands` must be dynamic values. + """ + self.operands = operands + + def block_arg_types(self): + """ + Return the list of block args types. + """ + return [operand.type for operand in self.operands] + + def set_func_args(self, block_args): + """ + This function is called after entering a function. `block_args` are the + block arguments that correspond to the passed operands. Derived classes + may implement this function to provide convenience getters for block + arguments. + """ + pass + + def __len__(self): + """ + Return the length of variadic operands. 
+ """ + return len(self.operands) + + +class FuncArgWithAttr(IRValue): + """ + This derived class is specifically for func op arg with attr + """ + + def __init__(self, ty, attr_name, attr_ty, attr_value=None): + super().__init__(ty) + assert attr_name is not None and ( + attr_ty is not None or attr_value is not None + ), "Invalid attr_name and/or attr_ty and/or attr_value for FuncArgWithAttr" + self.attr_name = attr_name + self.attr_ty = attr_ty + self.attr_value = attr_value + + + +def implicitDowncastNumericType(value): + if isinstance(value, Numeric): + return value.ir_value() + return value + + +__all__ = [ + "DslType", + "Numeric", + "NumericMeta", + "IntegerMeta", + "FloatMeta", + "Boolean", + "Integer", + "Int16", + "Int32", + "Int64", + "Int128", + "Int8", + "Uint8", + "Uint16", + "Uint32", + "Uint64", + "Uint128", + "Float", + "Float16", + "BFloat16", + "TFloat32", + "Float32", + "Float64", + "Float8E5M2", + "Float8E4M3", + "Float8E4M3FN", + "Float8E4M3B11FNUZ", + "Float8E4M3", + "Float8E8M0FNU", + "Float4E2M1FN", + "Float6E2M3FN", + "Float6E3M2FN", + "as_numeric", + "align", + "Pointer", + "dtype", + "Constexpr", + "IRConst", + "IRValue", + "IRVariadic", + "implicitDowncastNumericType", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4bfb2b7d91ee72b04a89de59e7dfbdec2be646c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from . import stacktrace +from . import logger +from . import timer +__all__ = [ + "logger", + "timer", + "stacktrace", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/logger.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..d4e4b4edf359ec86b6b5806cb0b2296f9cb918f6 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/logger.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides logging helper functions +""" + +import logging + +logger = None + + +def log(): + return logger + + +def setup_log( + name, log_to_console=False, log_to_file=False, log_file_path=None, log_level=1 +): + """Set up and configure a logger with console and/or file handlers. 
+ + :param name: Name of the logger to create + :type name: str + :param log_to_console: Whether to enable logging to console, defaults to False + :type log_to_console: bool, optional + :param log_to_file: Whether to enable logging to file, defaults to False + :type log_to_file: bool, optional + :param log_file_path: Path to the log file, required if log_to_file is True + :type log_file_path: str, optional + :param log_level: Logging level to set, defaults to 1 + :type log_level: int, optional + :raises ValueError: If log_to_file is True but log_file_path is not provided + :return: Configured logger instance + :rtype: logging.Logger + """ + # Create a custom logger + global logger + logger = logging.getLogger(name) + if log_to_console or log_to_file: + logger.setLevel(log_level) + else: + # Makes sure logging is OFF + logger.setLevel(logging.CRITICAL + 1) + + # Clear existing handlers to prevent duplicate logs + if logger.hasHandlers(): + logger.handlers.clear() + + # Define formatter + formatter = logging.Formatter( + f"%(asctime)s - %(name)s - %(levelname)s - [%(funcName)s] - %(message)s" + ) + + # Add console handler if enabled + if log_to_console: + console_handler = logging.StreamHandler() + console_handler.setLevel(log_level) + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + # Add file handler if enabled + if log_to_file: + if not log_file_path: + raise ValueError("log_file_path must be provided when enable_file is True") + file_handler = logging.FileHandler(log_file_path) + file_handler.setLevel(log_level) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + return logger + + +logger = setup_log("generic") diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/stacktrace.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/stacktrace.py new file mode 100644 index 
0000000000000000000000000000000000000000..d2091098c173e8a941ed7958802dfbdee24199bc --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/stacktrace.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" + This module provides stacktrace helper functions +""" + +import os +import re + + +def walk_to_top_module(start_path): + """ + Walk up from the start_path to find the top-level Python module. + + :param start_path: The path to start from. + :return: The path of the top-level module. + """ + current_path = start_path + + while True: + # Check if we are at the root directory + if os.path.dirname(current_path) == current_path: + break + + # Check for __init__.py + init_file_path = os.path.join(current_path, "__init__.py") + if os.path.isfile(init_file_path): + # If __init__.py exists, move up one level + current_path = os.path.dirname(current_path) + else: + # If no __init__.py, we are not in a module; stop + break + + # If we reached the root without finding a module, return None + if os.path.dirname(current_path) == current_path and not os.path.isfile( + os.path.join(current_path, "__init__.py") + ): + return None + + # Return the path of the top-level module + return current_path + + +def _filter_internal_frames(traceback, internal_path): + """ + Filter out stack frames from the traceback that belong to the specified module path. 
+ + This function removes stack frames from the traceback whose file paths start with + the given prefix_path, effectively hiding internal implementation details from + the error traceback shown to users. + """ + iter_prev = None + iter_tb = traceback + while iter_tb is not None: + if os.path.abspath(iter_tb.tb_frame.f_code.co_filename).startswith( + internal_path + ): + if iter_tb.tb_next: + if iter_prev: + iter_prev.tb_next = iter_tb.tb_next + else: + traceback = iter_tb.tb_next + else: + iter_prev = iter_tb + iter_tb = iter_tb.tb_next + return traceback + + +_generated_function_names = re.compile( + r"^(loop_body|while_region|while_before_block|while_after_block|if_region|then_block|else_block|elif_region)_\d+$" +) + + +def _filter_duplicated_frames(traceback): + """ + Filter out duplicated stack frames from the traceback. + The function filters out consecutive frames that are in the same file and have the same line number. + In a sequence of consecutive frames, the logic prefers to keep the non-generated frame or the last frame. 
+ """ + iter_prev = None + iter_tb = traceback + while iter_tb is not None: + skip_current = False + skip_next = False + if iter_tb.tb_next: + current_filename = os.path.abspath(iter_tb.tb_frame.f_code.co_filename) + next_filename = os.path.abspath(iter_tb.tb_next.tb_frame.f_code.co_filename) + # if in the same file, check if the line number is the same + if current_filename == next_filename: + current_lineno = iter_tb.tb_lineno + next_lineno = iter_tb.tb_next.tb_lineno + if current_lineno == next_lineno: + # Same file and line number, check name, if current is generated, skip current, otherwise skip next + name = iter_tb.tb_frame.f_code.co_name + is_generated = bool(_generated_function_names.match(name)) + if is_generated: + # Skip current + skip_current = True + else: + # Skip next if it's generated, otherwise keep both + next_name = iter_tb.tb_next.tb_frame.f_code.co_name + skip_next = bool(_generated_function_names.match(next_name)) + if skip_current: + if iter_prev: + iter_prev.tb_next = iter_tb.tb_next + else: + traceback = iter_tb.tb_next + elif skip_next: + # if next is last frame, don't skip + if iter_tb.tb_next.tb_next: + iter_tb.tb_next = iter_tb.tb_next.tb_next + iter_prev = iter_tb + else: + iter_prev = iter_tb + iter_tb = iter_tb.tb_next + + return traceback + + +def filter_stackframe(traceback, prefix_path): + """ + Filter out stack frames from the traceback that belong to the specified module path. + + This function removes stack frames from the traceback whose file paths start with + the given prefix_path, effectively hiding internal implementation details from + the error traceback shown to users. + + :param traceback: The traceback object to filter. + :param prefix_path: The path prefix to filter out from the traceback. + :return: The filtered traceback with internal frames removed. 
+ """ + # Step 1: filter internal frames + traceback = _filter_internal_frames(traceback, prefix_path) + + # Step 2: consolidate duplicated frames + return _filter_duplicated_frames(traceback) + + +def filter_exception(value, module_dir): + """ + Filter out internal implementation details from exception traceback. + + This function recursively processes an exception and its cause chain, + removing stack frames that belong to the specified module directory. + This helps to present cleaner error messages to users by hiding + implementation details. + + :param value: The exception object to filter. + :param module_dir: The module directory path to filter out from tracebacks. + :return: The filtered exception with internal frames removed. + """ + if hasattr(value, "__cause__") and value.__cause__: + filter_exception(value.__cause__, module_dir) + + if hasattr(value, "__traceback__"): + filter_stackframe(value.__traceback__, module_dir) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/timer.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/timer.py new file mode 100644 index 0000000000000000000000000000000000000000..f41d3f7410c0227ff1b1f8df4b8ce14557cf649b --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/base_dsl/utils/timer.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ +""" +This module provides a timing helper functions +""" +from functools import wraps + +from .logger import log + + +# TODO: revisit this part when mlir timing manager is ready for pybind. +def timer(*dargs, **kwargs): + enable = kwargs.get("enable", True) + + def decorator(func): + @wraps(func) + def func_wrapper(*args, **kwargs): + if not enable: + return func(*args, **kwargs) + from time import time + + start = time() + result = func(*args, **kwargs) + end = time() + + # Convert time from seconds to us + spend_us = (end - start) * 1e6 + + # Determine the function type and format the log message + if hasattr(func, "__name__"): + func_name = func.__name__ + log_message = f"[JIT-TIMER] Function: {func_name} | Execution Time: {spend_us:.2f} µs" + elif "CFunctionType" in str(type(func)): + log_message = f"[JIT-TIMER] C API Function: {str(func)} | Execution Time: {spend_us:.2f} µs" + else: + log_message = f"[JIT-TIMER] Anonymous Function | Execution Time: {spend_us:.2f} µs" + + log().info(log_message) + + return result + + return func_wrapper + + if len(dargs) == 1 and callable(dargs[0]): + return decorator(dargs[0]) + else: + return decorator diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2c7ed2607675990ad9579fa06b25935b2ccb46e --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/__init__.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from .cutlass_dsl import ( + Constexpr, + as_numeric, + min, + max, + and_, + or_, + all_, + any_, + not_, + all_, + any_, + select_, + # Control-flow without AST pre-processor + if_generate, + for_generate, + LoopUnroll, + while_generate, + yield_out, + # Control-flow with AST pre-processor + range_constexpr, + range_dynamic, + const_expr, + dynamic_expr, + # Data types + dtype, # Provides conversions to types inheriting from NumericType + DSLRuntimeError, + JitArgAdapterRegistry, + # Construction utilities for user-defined classes + extract_mlir_values, + new_from_mlir_values, +) + +from .cute.typing import * + +# Utilities not belonging to CuTe +from . import utils as utils + +# Used as internal symbol +from . import cutlass_dsl as _dsl + +# Aliases +LaunchConfig = _dsl.BaseDSL.LaunchConfig +register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter +gpu = _dsl.cutlass_gpu +cuda = _dsl.cuda_helpers + +CACHE_FILE = "compiled_cache.db" diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8702ed9163837925057b48f9aafd11cffbb26a7e --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/__init__.py @@ -0,0 +1,319 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +# Use the auto-generated enum AddressSpace +from cutlass._mlir.dialects.cute import AddressSpace + +# Explicitly import types that might be directly used by other modules. +# This is a fix for using Sphinx to generate documentation +# Because Sphinx processes each module in isolation, it won't be able to rely +# on re-exported symbols via wildcard imports (from .typing import *) in the +# same way that Python does at runtime. +from .typing import ( + Shape, + Stride, + IntTuple, + Coord, + Tile, + XTuple, + Tiler, + Layout, + Pointer, + Tensor, +) + +# Import everything else +from .typing import * + +from .core import ( + assume, + is_integer, + is_int_tuple, + is_static, + size, + has_underscore, + slice_, + make_ptr, + make_layout, + recast_layout, + make_fragment_like, + depth, + rank, + flatten_to_tuple, + flatten, + unflatten, + product, + product_like, + shape, + size_in_bytes, + make_identity_layout, + make_ordered_layout, + make_composed_layout, + make_layout_tv, + make_swizzle, + recast_ptr, + make_tensor, + make_identity_tensor, + make_fragment, + recast_tensor, + get, + select, + front, + is_major, + leading_dim, + find, + find_if, + coalesce, + group_modes, + cosize, + dice, + product_each, + prepend, + append, + prepend_ones, + append_ones, + ceil_div, + slice_and_offset, + crd2idx, + domain_offset, + elem_less, + transform_leaf, + filter_zeros, + filter, + tile_to_shape, + shape_div, + composition, + complement, + right_inverse, + left_inverse, + max_common_layout, + max_common_vector, + logical_product, + zipped_product, + 
tiled_product, + flat_product, + raked_product, + blocked_product, + flat_divide, + logical_divide, + zipped_divide, + tiled_divide, + local_partition, + local_tile, + printf, + print_tensor, + # tiled mma/tiled copy + make_mma_atom, + make_tiled_mma, + make_copy_atom, + make_tiled_copy_tv, + make_tiled_copy, + make_tiled_copy_S, + make_tiled_copy_D, + make_tiled_copy_A, + make_tiled_copy_B, + make_tiled_copy_C, + make_tiled_copy_C_atom, + basic_copy, + basic_copy_if, + autovec_copy, + copy, + copy_atom_call, + gemm, + # Wrapper classes + ComposedLayout, + Swizzle, + E, + Atom, + MmaAtom, + CopyAtom, + TiledCopy, + TiledMma, + TensorSSA, + ReductionOp, + full, + full_like, + empty_like, + ones_like, + zeros_like, + where, + any_, + all_, + # User defined struct + struct, + pretty_str, + make_layout_image_mask, + repeat_like, + round_up, + is_congruent, + is_weakly_congruent, + ScaledBasis, + get_divisibility, + Ratio, +) + +from . import arch +from . import nvgpu +from . import testing +from . import runtime + +# Export all math ops without "math." +from .math import * + +# Used as internal symbol +from .. 
import cutlass_dsl as _dsl + +# Aliases +jit = _dsl.CuTeDSL.jit +kernel = _dsl.CuTeDSL.kernel +register_jit_arg_adapter = _dsl.JitArgAdapterRegistry.register_jit_arg_adapter +compile = _dsl.compile + +# Explicitly export all symbols for documentation generation +__all__ = [ + # Core types + "AddressSpace", + "Tensor", + "Layout", + "ComposedLayout", + "Swizzle", + "E", + "Atom", + "MmaAtom", + "CopyAtom", + "TiledCopy", + "TiledMma", + "TensorSSA", + # Basic utility functions + "assume", + "is_integer", + "is_int_tuple", + "is_static", + "size", + "has_underscore", + "slice_", + "depth", + "rank", + "shape", + "printf", + "print_tensor", + "pretty_str", + # Layout functions + "make_layout", + "recast_layout", + "make_identity_layout", + "make_ordered_layout", + "make_composed_layout", + "make_layout_tv", + "make_layout_image_mask", + # Tensor functions + "make_ptr", + "make_tensor", + "make_identity_tensor", + "make_fragment", + "make_fragment_like", + "recast_ptr", + "recast_tensor", + # Tensor manipulation + "get", + "select", + "front", + "is_major", + "leading_dim", + "find", + "find_if", + "coalesce", + "group_modes", + "cosize", + "size_in_bytes", + # Tuple operations + "flatten_to_tuple", + "flatten", + "product", + "product_like", + "product_each", + "prepend", + "append", + "prepend_ones", + "append_ones", + # Math operations + "ceil_div", + "round_up", + # Layout operations + "slice_and_offset", + "crd2idx", + "domain_offset", + "elem_less", + "filter_zeros", + "filter", + "tile_to_shape", + "shape_div", + "dice", + # Layout algebra + "composition", + "complement", + "right_inverse", + "left_inverse", + "max_common_layout", + "max_common_vector", + "is_congruent", + "is_weakly_congruent", + # Product operations + "logical_product", + "zipped_product", + "tiled_product", + "flat_product", + "raked_product", + "blocked_product", + # Division operations + "flat_divide", + "logical_divide", + "zipped_divide", + "tiled_divide", + "local_partition", + 
"local_tile", + # MMA and Copy operations + "make_mma_atom", + "make_tiled_mma", + "make_copy_atom", + "make_tiled_copy_tv", + "make_tiled_copy", + "make_tiled_copy_C_atom", + "basic_copy", + "basic_copy_if", + "autovec_copy", + "copy", + "copy_atom_call", + "gemm", + # Tensor creation + "full", + "full_like", + "empty_like", + "ones_like", + "zeros_like", + "where", + "any_", + "all_", + "repeat_like", + "ScaledBasis", + # User defined struct + "struct", + # Modules + "arch", + "nvgpu", + "testing", + "runtime", + # Decorators and code generation + "jit", + "kernel", + "register_jit_arg_adapter", + "compile", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..01198215f74b07f224b1d5e53ff37075775bb201 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/__init__.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ +from .elect import * +from .mbar import * +from .nvvm_wrappers import * +from .smem import * +from .tmem import * + +# __all__ is required here for documentation generation +__all__ = [ + # + # elect.py + # + "make_warp_uniform", + "elect_one", + # + # mbar.py + # + "mbarrier_init", + "mbarrier_init_fence", + "mbarrier_arrive_and_expect_tx", + "mbarrier_expect_tx", + "mbarrier_wait", + "mbarrier_try_wait", + "mbarrier_conditional_try_wait", + "mbarrier_arrive", + # + # nvvm_wrappers.py + # + "lane_idx", + "warp_idx", + "thread_idx", + "block_dim", + "block_idx", + "grid_dim", + "cluster_idx", + "cluster_dim", + "block_in_cluster_idx", + "block_in_cluster_dim", + "block_idx_in_cluster", + "shuffle_sync", + "shuffle_sync_up", + "shuffle_sync_down", + "shuffle_sync_bfly", + "barrier", + "barrier_arrive", + "sync_threads", + "sync_warp", + "fence_acq_rel_cta", + "fence_acq_rel_cluster", + "fence_acq_rel_gpu", + "fence_acq_rel_sys", + "cp_async_commit_group", + "cp_async_wait_group", + "cp_async_bulk_commit_group", + "cp_async_bulk_wait_group", + "cluster_wait", + "cluster_arrive", + "cluster_arrive_relaxed", + "fence_proxy", + "vote_ballot_sync", + "popc", + "fence_view_async_tmem_load", + "fence_view_async_tmem_store", + "warpgroup_reg_alloc", + "warpgroup_reg_dealloc", + "fma_packed_f32x2", + "mul_packed_f32x2", + "add_packed_f32x2", + "fmax", + "rcp_approx", + "exp2", + # Constants + "WARP_SIZE", + # Forward from auto-generated nvvm python + "ProxyKind", + "SharedSpace", + "RoundingModeKind", + # + # smem.py + # + "alloc_smem", + "get_dyn_smem", + "get_dyn_smem_size", + # + # tmem.py + # + "retrieve_tmem_ptr", + "alloc_tmem", + "relinquish_tmem_alloc_permit", + "dealloc_tmem", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/elect.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/elect.py new file mode 100644 index 
0000000000000000000000000000000000000000..ead552afab7de50a62f95eee7b4d8a2d9b4dfca9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/elect.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from cutlass.cutlass_dsl import CuTeDSL, T, dsl_user_op + +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir.dialects import nvvm, scf +from cutlass._mlir import ir + +from ..typing import Int, Int32 +from ...impl_utils import check_value_in + + +@dsl_user_op +def make_warp_uniform(value: Int, *, loc=None, ip=None) -> Int32: + """ + Creates a warp-uniform value from the given integer input. + + :param value: The integer to make warp uniform. + :type value: Int + :return: The warp-uniform value equal to the input. + :rtype: Int32 + """ + return Int32( + _cute_nvgpu_ir.arch_make_warp_uniform( + Int32(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + ) + + +class IfOpRegion: + """ + A context manager for if Op. + Automatically inserts `scf.yield([])` when exiting the context. 
+ """ + + def __init__(self, block, *, loc=None, ip=None): + self.block = block + self.insert_point = ir.InsertionPoint(self.block) + self.loc = loc + self.ip = ip + + def __enter__(self): + self.insert_point.__enter__() + return self.block.arguments + + def __exit__(self, exc_type, exc_value, traceback): + scf.yield_([], loc=self.loc, ip=self.ip) + self.insert_point.__exit__(exc_type, exc_value, traceback) + + +@dsl_user_op +def elect_one(*, loc=None, ip=None) -> IfOpRegion: + """ + Elects one thread within a warp. + + .. code-block:: python + + with elect_one(): + # Only one thread in the warp executes the code in this context + pass + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + is_thread_leader = nvvm.elect_sync(T.bool()) + if_op = scf.IfOp(is_thread_leader, loc=loc, ip=ip) + return IfOpRegion(if_op.then_block, loc=loc, ip=ip) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/mbar.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/mbar.py new file mode 100644 index 0000000000000000000000000000000000000000..80cb7b0b5fc6e226a39d68197382cbde2e32861d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/mbar.py @@ -0,0 +1,349 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+from typing import Optional + +from cutlass.cutlass_dsl import CuTeDSL, T, if_generate, dsl_user_op + +from cutlass._mlir.dialects import nvvm +from cutlass._mlir import ir + +from ..typing import Pointer, Int, Boolean, Int32 +from ...impl_utils import check_value_in + + +#################################################################################################### +# +# Mbarrier management utilities +# +#################################################################################################### + + +@dsl_user_op +def mbarrier_init(mbar_ptr: Pointer, cnt: Int, *, loc=None, ip=None) -> None: + """ + Initializes a mbarrier with the specified thread arrival count. + + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + :param cnt: The arrival count of the mbarrier + :type cnt: Int + """ + nvvm.mbarrier_init_shared( + mbar_ptr.llvm_ptr, Int32(cnt).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + + +@dsl_user_op +def mbarrier_init_fence(*, loc=None, ip=None) -> None: + """ + A fence operation that applies to the mbarrier initializations. + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + nvvm.fence_mbarrier_init(loc=loc, ip=ip) + + +@dsl_user_op +def mbarrier_arrive_and_expect_tx( + mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None +) -> None: + """ + Arrives on a mbarrier and expects a specified number of transaction bytes. + + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + :param bytes: The number of transaction bytes + :type bytes: Int + :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to + the mbarrier is converted to a remote address in the peer CTA's + SMEM. 
+ """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + + mbar_llvm_ptr = mbar_ptr.llvm_ptr + if peer_cta_rank_in_cluster is not None: + mbar_llvm_ptr = nvvm.mapa_shared_cluster( + mbar_llvm_ptr.type, + mbar_llvm_ptr, + Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + space = nvvm.MBarrierSpaceKind.CLUSTER + else: + space = nvvm.MBarrierSpaceKind.CTA + + nvvm.mbarrier_txn( + mbar_llvm_ptr, + Int32(bytes).ir_value(loc=loc, ip=ip), + kind=nvvm.MBarrierTxnKind.ARRIVE_EXPECT_TX, + space=space, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def mbarrier_expect_tx( + mbar_ptr: Pointer, bytes: Int, peer_cta_rank_in_cluster=None, *, loc=None, ip=None +) -> None: + """ + Expects a specified number of transaction bytes without an arrive. + + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + :param bytes: The number of transaction bytes + :type bytes: Int + :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to + the mbarrier is converted to a remote address in the peer CTA's + SMEM. + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + + mbar_llvm_ptr = mbar_ptr.llvm_ptr + if peer_cta_rank_in_cluster is not None: + mbar_llvm_ptr = nvvm.mapa( + mbar_llvm_ptr.type, + mbar_llvm_ptr, + Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + space = nvvm.MBarrierSpaceKind.CLUSTER + else: + space = nvvm.MBarrierSpaceKind.CTA + + nvvm.mbarrier_txn( + mbar_llvm_ptr, + Int32(bytes).ir_value(loc=loc, ip=ip), + kind=nvvm.MBarrierTxnKind.EXPECT_TX, + space=space, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def mbarrier_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> None: + """ + Waits on a mbarrier with a specified phase. 
+ + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + :param phase: The phase to wait for (either 0 or 1) + :type phase: Int + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + + timeout_ns = 10000000 + # This NVVM Op is a spin-loop wrapping the mbarrier.try_wait.parity.shared.b64 PTX + # The timeout in ns only applies to the latter and this call is truly blocking + nvvm.mbarrier_try_wait_parity_shared( + mbar_ptr.llvm_ptr, + Int32(phase).ir_value(loc=loc, ip=ip), + Int32(timeout_ns).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def mbarrier_try_wait(mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None) -> Boolean: + """ + Attempts to wait on a mbarrier with a specified phase in a non-blocking fashion. + + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + :param phase: The phase to wait for (either 0 or 1) + :type phase: Int + :return: A boolean value indicating whether the wait operation was successful + :rtype: Boolean + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + + return Boolean( + nvvm.mbarrier_wait_parity( + T.bool(), + mbar_ptr.llvm_ptr, + Int32(phase).ir_value(loc=loc, ip=ip), + nvvm.MBarrierWaitKind.TRY, + loc=loc, + ip=ip, + ) + ) + + +@dsl_user_op +def mbarrier_conditional_try_wait( + cond, mbar_ptr: Pointer, phase: Int, *, loc=None, ip=None +) -> Boolean: + """ + Conditionally attempts to wait on a mbarrier with a specified phase in a non-blocking fashion. 
+ + :param cond: A boolean predicate + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + :param phase: The phase to wait for (either 0 or 1) + :type phase: Int + :return: A boolean value indicating whether the wait operation was successful + :rtype: Boolean + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + return if_generate( + cond, + lambda: mbarrier_try_wait(mbar_ptr, phase, loc=loc, ip=ip), + lambda: Boolean(True).ir_value(loc=loc, ip=ip), + None, + [Boolean], + ) + + +@dsl_user_op +def mbarrier_arrive( + mbar_ptr: Pointer, + peer_cta_rank_in_cluster: Optional[Int] = None, + *, + loc=None, + ip=None, +) -> None: + """ + Arrives on an mbarrier. + + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + :param peer_cta_rank_in_cluster: An optional CTA rank in cluster. If provided, the pointer to + the mbarrier is converted to a remote address in the peer CTA's + SMEM. + """ + mbar_llvm_ptr = mbar_ptr.llvm_ptr + if peer_cta_rank_in_cluster is not None: + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + + mbar_llvm_ptr = nvvm.mapa_shared_cluster( + mbar_llvm_ptr.type, + mbar_llvm_ptr, + Int32(peer_cta_rank_in_cluster).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + space = nvvm.MBarrierSpaceKind.CLUSTER + else: + space = nvvm.MBarrierSpaceKind.CTA + + nvvm.mbarrier_txn( + mbar_llvm_ptr, + Int32(1).ir_value(loc=loc, ip=ip), + kind=nvvm.MBarrierTxnKind.ARRIVE, + space=space, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def cp_async_mbarrier_arrive_noinc(mbar_ptr: Pointer, *, loc=None, ip=None) -> None: + """ + Arrives on an mbarrier for async load **without incrementing** the arrival count + (`cp.async.mbarrier.arrive.shared ..., noinc=1`). 
+ Used in the warp-specialized kernel when the non-TMA load warp(producer) is not the same + as the math/epilogue warp(consumer). + + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + """ + arch = CuTeDSL._get_dsl().envar.arch + check_value_in( + arch, + [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ], + "arch", + ) + + mbar_llvm_ptr = mbar_ptr.llvm_ptr + nvvm.cp_async_mbarrier_arrive_shared( + mbar_llvm_ptr, + noinc=True, + loc=loc, + ip=ip, + ) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..69e3b8acb1fd0d1bc6615cd835235c0bbd62027b --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/nvvm_wrappers.py @@ -0,0 +1,681 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ +from functools import partial +from typing import Optional, Tuple, Union, Callable +from typing_extensions import deprecated + +from cutlass.cutlass_dsl import T, dsl_user_op + +from cutlass._mlir import ir +from cutlass._mlir.dialects import llvm, nvvm, vector + +# Forward nvvm enums +from cutlass._mlir.dialects.nvvm import ( + ProxyKind, + SharedSpace, + Tcgen05WaitKind, + SetMaxRegisterAction, + RoundingModeKind, +) + +from ..typing import ( + Int, + Boolean, + Int16, + Uint16, + Int32, + Uint32, + Int64, + Float32, + BFloat16, + Numeric, + as_numeric, +) + +WARP_SIZE = 32 +FULL_MASK = 0xFFFFFFFF + + +@dsl_user_op +def lane_idx(*, loc=None, ip=None) -> Int32: + """ + Returns the lane index of the current thread within the warp. + """ + return Int32(nvvm.read_ptx_sreg_laneid(T.i32(), loc=loc, ip=ip)) + + +@dsl_user_op +def warp_idx(*, loc=None, ip=None) -> Int32: + """ + Returns the warp index within a CTA. + """ + warp_size = 32 + tid_x = Int32(nvvm.read_ptx_sreg_tid_x(T.i32(), loc=loc, ip=ip)) + tid_y = Int32(nvvm.read_ptx_sreg_tid_y(T.i32(), loc=loc, ip=ip)) + tid_z = Int32(nvvm.read_ptx_sreg_tid_z(T.i32(), loc=loc, ip=ip)) + ntid_x = Int32(nvvm.read_ptx_sreg_ntid_x(T.i32(), loc=loc, ip=ip)) + ntid_y = Int32(nvvm.read_ptx_sreg_ntid_y(T.i32(), loc=loc, ip=ip)) + tid = tid_x + tid_y * ntid_x + tid_z * ntid_x * ntid_y + return tid // warp_size + + +@dsl_user_op +def thread_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: + """ + Returns the thread index within a CTA. + """ + return ( + Int32(nvvm.read_ptx_sreg_tid_x(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_tid_y(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_tid_z(T.i32(), loc=loc, ip=ip)), + ) + + +@dsl_user_op +def block_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: + """ + Returns the number of threads in each dimension of the CTA. 
+ """ + return ( + Int32(nvvm.read_ptx_sreg_ntid_x(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_ntid_y(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_ntid_z(T.i32(), loc=loc, ip=ip)), + ) + + +@dsl_user_op +def block_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: + """ + Returns the CTA identifier within a grid. + """ + return ( + Int32(nvvm.read_ptx_sreg_ctaid_x(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_ctaid_y(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_ctaid_z(T.i32(), loc=loc, ip=ip)), + ) + + +@dsl_user_op +def grid_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: + """ + Returns the number of CTAs in each dimension of the grid. + """ + return ( + Int32(nvvm.read_ptx_sreg_nctaid_x(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_nctaid_y(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_nctaid_z(T.i32(), loc=loc, ip=ip)), + ) + + +@dsl_user_op +def cluster_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: + """ + Returns the cluster identifier within a grid. + """ + return ( + Int32(nvvm.read_ptx_sreg_clusterid_x(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_clusterid_y(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_clusterid_z(T.i32(), loc=loc, ip=ip)), + ) + + +@dsl_user_op +def cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: + """ + Returns the number of clusters in each dimension of the grid. + """ + return ( + Int32(nvvm.read_ptx_sreg_nclusterid_x(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_nclusterid_y(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_nclusterid_z(T.i32(), loc=loc, ip=ip)), + ) + + +@dsl_user_op +def block_in_cluster_idx(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: + """ + Returns the CTA index within a cluster across all dimensions. 
+ """ + return ( + Int32(nvvm.read_ptx_sreg_cluster_ctaid_x(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_cluster_ctaid_y(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_cluster_ctaid_z(T.i32(), loc=loc, ip=ip)), + ) + + +@dsl_user_op +def block_in_cluster_dim(*, loc=None, ip=None) -> Tuple[Int32, Int32, Int32]: + """ + Returns the dimensions of the cluster. + """ + return ( + Int32(nvvm.read_ptx_sreg_cluster_nctaid_x(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_cluster_nctaid_y(T.i32(), loc=loc, ip=ip)), + Int32(nvvm.read_ptx_sreg_cluster_nctaid_z(T.i32(), loc=loc, ip=ip)), + ) + + +@dsl_user_op +def block_idx_in_cluster(*, loc=None, ip=None) -> Int32: + """ + Returns the linearized identifier of the CTA within the cluster. + """ + return Int32(nvvm.read_ptx_sreg_cluster_ctarank(T.i32(), loc=loc, ip=ip)) + + +@dsl_user_op +def shuffle_sync_op( + value: Numeric, + offset: Int, + mask: Int = FULL_MASK, + mask_and_clamp: Int = WARP_SIZE - 1, + kind: nvvm.ShflKind = nvvm.ShflKind.idx, + *, + loc=None, + ip=None, +) -> Numeric: + """ + Shuffles a value within the threads of a warp. + + :param value: The value to shuffle + :type value: Numeric + :param mask: A mask describing the threads participating in this operation + :type mask: Int + :param offset: A source lane or a source lane offset depending on kind + :type offset: Int + :param mask_and_clamp: An integer containing two packed values specifying a mask for logically + splitting warps into sub-segments and an upper bound for clamping the + source lane index. 
+ :type mask_and_clamp: Int + :param kind: The kind of shuffle, can be idx, up, down, or bfly + :type kind: ShflKind + :return: The shuffled value + :rtype: Numeric + """ + if not isinstance(value, Numeric): + value = as_numeric(value) + if value.width > 64: + raise ValueError("shuffle_sync only supports values up to 64 bits") + + orig_type = type(value) + if value.width < 32: + if value.dtype.is_float: + value = value.to(Float32) + else: + if value.signed: + value = value.to(Int32) + else: + value = value.to(Uint32) + return orig_type( + nvvm.shfl_sync( + type(value).mlir_type, + Int32(mask).ir_value(loc=loc, ip=ip), + value.ir_value(loc=loc, ip=ip), + Int32(offset).ir_value(loc=loc, ip=ip), + Int32(mask_and_clamp).ir_value(loc=loc, ip=ip), + kind, + loc=loc, + ip=ip, + ) + ) + elif value.width == 32: + return orig_type( + nvvm.shfl_sync( + type(value).mlir_type, + Int32(mask).ir_value(loc=loc, ip=ip), + value.ir_value(loc=loc, ip=ip), + Int32(offset).ir_value(loc=loc, ip=ip), + Int32(mask_and_clamp).ir_value(loc=loc, ip=ip), + kind, + loc=loc, + ip=ip, + ) + ) + else: + if value.width != 64: + raise ValueError( + "shuffle_sync only supports 64 bits values when the bit width is larger than 32" + ) + value = llvm.bitcast( + T.i64(), value.to(ir.Value, loc=loc, ip=ip), loc=loc, ip=ip + ) + # extract low 32 bits + low_32_bits = llvm.trunc( + T.i32(), value, llvm.IntegerOverflowFlags.none, loc=loc, ip=ip + ) + # extract high 32 bits + high_32_bits = llvm.lshr( + value, Int64(32).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + high_32_bits = llvm.trunc( + T.i32(), high_32_bits, llvm.IntegerOverflowFlags.none, loc=loc, ip=ip + ) + + low_32_bits_shfl = nvvm.shfl_sync( + T.i32(), + Int32(mask).ir_value(loc=loc, ip=ip), + low_32_bits, + Int32(offset).ir_value(loc=loc, ip=ip), + Int32(mask_and_clamp).ir_value(loc=loc, ip=ip), + kind, + loc=loc, + ip=ip, + ) + high_32_bits_shfl = nvvm.shfl_sync( + T.i32(), + Int32(mask).ir_value(loc=loc, ip=ip), + high_32_bits, + 
Int32(offset).ir_value(loc=loc, ip=ip), + Int32(mask_and_clamp).ir_value(loc=loc, ip=ip), + kind, + loc=loc, + ip=ip, + ) + + # combine low and high 32 bits + low_64_bit = llvm.zext(T.i64(), low_32_bits_shfl, loc=loc, ip=ip) + high_64_bit = llvm.zext(T.i64(), high_32_bits_shfl, loc=loc, ip=ip) + shlf_res = llvm.shl( + high_64_bit, + Int64(32).ir_value(loc=loc, ip=ip), + llvm.IntegerOverflowFlags.none, + loc=loc, + ip=ip, + ) + shlf_res = llvm.or_(shlf_res, low_64_bit, loc=loc, ip=ip) + shlf_res = llvm.bitcast(orig_type.mlir_type, shlf_res, loc=loc, ip=ip) + return orig_type(shlf_res) + +shuffle_sync = partial(shuffle_sync_op, kind=nvvm.ShflKind.idx) +shuffle_sync_up = partial(shuffle_sync_op, kind=nvvm.ShflKind.up) +shuffle_sync_down = partial(shuffle_sync_op, kind=nvvm.ShflKind.down) +shuffle_sync_bfly = partial(shuffle_sync_op, kind=nvvm.ShflKind.bfly) + + +@dsl_user_op +def barrier(*, barrier_id=None, number_of_threads=None, loc=None, ip=None) -> None: + """ + Creates a barrier, optionally named. + """ + if barrier_id is not None: + barrier_id = Int32(barrier_id).ir_value(loc=loc, ip=ip) + + if number_of_threads is not None: + number_of_threads = Int32(number_of_threads).ir_value(loc=loc, ip=ip) + + nvvm.barrier( + barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip + ) + + +@dsl_user_op +def barrier_arrive( + *, barrier_id=None, number_of_threads=None, loc=None, ip=None +) -> None: + if barrier_id is not None: + barrier_id = Int32(barrier_id).ir_value(loc=loc, ip=ip) + + if number_of_threads is None: + raise ValueError( + "barrier_arrive needs pass number_of_threads to arrive the barrier", + ) + number_of_threads = Int32(number_of_threads).ir_value(loc=loc, ip=ip) + + nvvm.barrier_arrive( + barrier_id=barrier_id, number_of_threads=number_of_threads, loc=loc, ip=ip + ) + + +@dsl_user_op +def sync_threads(*, loc=None, ip=None) -> None: + """ + Synchronizes all threads within a CTA. 
+ """ + nvvm.barrier(loc=loc, ip=ip) + + +@dsl_user_op +def sync_warp(mask: Int = FULL_MASK, *, loc=None, ip=None) -> None: + """ + Performs a warp-wide sync with an optional mask. + """ + nvvm.bar_warp_sync(Int32(mask).ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + + +@dsl_user_op +def fence_acq_rel_cta(*, loc=None, ip=None) -> None: + """ + Fence operation with acquire-release semantics. + + See the `PTX documentation `__. + """ + nvvm.fence_acq_rel_cta(loc=loc, ip=ip) + + +@dsl_user_op +def fence_acq_rel_cluster(*, loc=None, ip=None) -> None: + """ + Fence operation with acquire-release semantics. + + See the `PTX documentation `__. + """ + nvvm.fence_acq_rel_cluster(loc=loc, ip=ip) + + +@dsl_user_op +def fence_acq_rel_gpu(*, loc=None, ip=None) -> None: + """ + Fence operation with acquire-release semantics. + + See the `PTX documentation `__. + """ + nvvm.fence_acq_rel_gpu(loc=loc, ip=ip) + + +@dsl_user_op +def fence_acq_rel_sys(*, loc=None, ip=None) -> None: + """ + Fence operation with acquire-release semantics. + + See the `PTX documentation `__. + """ + nvvm.fence_acq_rel_sys(loc=loc, ip=ip) + + +@dsl_user_op +def cp_async_commit_group(*, loc=None, ip=None) -> None: + """ + Commits all prior initiated but uncommitted cp.async instructions. + + See the `PTX documentation `__. + """ + nvvm.cp_async_commit_group(loc=loc, ip=ip) + + +@dsl_user_op +def cp_async_wait_group(n, *, loc=None, ip=None) -> None: + """ + Waits till only a specified numbers of cp.async groups are pending. + + See the `PTX documentation `__. + """ + nvvm.cp_async_wait_group(n, loc=loc, ip=ip) + + +@dsl_user_op +def cp_async_bulk_commit_group(*, loc=None, ip=None) -> None: + """ + Commits all prior initiated but uncommitted cp.async.bulk instructions. + + See the `PTX documentation `__. 
+ """ + nvvm.cp_async_bulk_commit_group(loc=loc, ip=ip) + + +@dsl_user_op +def cp_async_bulk_wait_group(group, *, read=None, loc=None, ip=None) -> None: + """ + Waits till only a specified numbers of cp.async.bulk groups are pending. + + See the `PTX documentation `__. + """ + nvvm.cp_async_bulk_wait_group(group, read=read, loc=loc, ip=ip) + + +@dsl_user_op +def cluster_wait(*, loc=None, ip=None) -> None: + """ + A cluster-wide wait operation. + """ + nvvm.cluster_wait(loc=loc, ip=ip) + + +@dsl_user_op +def cluster_arrive(*, aligned=None, loc=None, ip=None) -> None: + """ + A cluster-wide arrive operation. + """ + nvvm.cluster_arrive(aligned=aligned, loc=loc, ip=ip) + + +@dsl_user_op +def cluster_arrive_relaxed(*, aligned=None, loc=None, ip=None) -> None: + """ + A cluster-wide arrive operation with relaxed semantics. + """ + nvvm.cluster_arrive_relaxed(aligned=aligned, loc=loc, ip=ip) + + +@dsl_user_op +def fence_proxy( + kind: ProxyKind, + *, + space: Optional[SharedSpace] = None, + use_intrinsic=None, + loc=None, + ip=None, +) -> None: + nvvm.fence_proxy( + kind=kind, space=space, use_intrinsic=use_intrinsic, loc=loc, ip=ip + ) + + +@dsl_user_op +def vote_ballot_sync( + pred: Boolean, mask: Int = FULL_MASK, *, loc=None, ip=None +) -> Int32: + """ + Performs a ballot operation across the warp. + """ + return Int32( + nvvm.vote_ballot_sync( + T.i32(), + Int32(mask).ir_value(loc=loc, ip=ip), + Boolean(pred).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + + +@dsl_user_op +def popc(value: Numeric, *, loc=None, ip=None) -> Numeric: + """ + Performs a population count operation. + """ + if not isinstance(value, Numeric): + value = as_numeric(value) + return type(value)(llvm.intr_ctpop(value.ir_value(), loc=loc, ip=ip)) + + +@dsl_user_op +def fence_view_async_tmem_op( + kind: Tcgen05WaitKind, + *, + loc=None, + ip=None, +) -> None: + """ + Perform a fence operation on the async TMEM load or store. + + .. 
note:: + This function is only available on sm_100a and above. + The fence is required to synchronize the TMEM load/store + and let the pipeline release or commit the buffer. + + Take a mma2acc pipeline as an example of LOAD fence, the ACC tensor is from TMEM. + ``` + # Start to copy ACC from TMEM to register + cute.copy(tmem_load, tACC, rACC) + fence_view_async_tmem_load() + # After fence, we can ensure the TMEM buffer is consumed totally. + # Release the buffer to let the MMA know it can overwrite the buffer. + mma2accum_pipeline.consumer_release(curr_consumer_state) + ``` + Take a TS GEMM kernel as an example of STORE fence, the A tensor is from TMEM. + ``` + # Start to copy A from register to TMEM + cute.copy(tmem_store, rA, tA) + fence_view_async_tmem_store() + # After fence, we can ensure the TMEM buffer is ready. + # Commit the buffer to let the MMA know it can start to load A. + tmem_mma_pipeline.producer_commit(curr_producer_state) + ``` + + + :param kind: The kind of fence operation to perform including LOAD and STORE. 
+ :type kind: Tcgen05WaitKind + """ + nvvm.tcgen05_wait(kind, loc=loc, ip=ip) + + +fence_view_async_tmem_load = partial( + fence_view_async_tmem_op, kind=Tcgen05WaitKind.LOAD +) +fence_view_async_tmem_store = partial( + fence_view_async_tmem_op, kind=Tcgen05WaitKind.STORE +) + + +@dsl_user_op +def warpgroup_reg_realloc_op( + reg_count: int, + kind: SetMaxRegisterAction, + *, + loc=None, + ip=None, +) -> None: + nvvm.setmaxregister(reg_count, kind, loc=loc, ip=ip) + + +warpgroup_reg_alloc = partial( + warpgroup_reg_realloc_op, kind=SetMaxRegisterAction.increase +) +warpgroup_reg_dealloc = partial( + warpgroup_reg_realloc_op, kind=SetMaxRegisterAction.decrease +) + + +@dsl_user_op +def calc_packed_f32x2_op( + src_a: Tuple[Float32, Float32], + src_b: Tuple[Float32, Float32], + src_c: Tuple[Float32, Float32] | None, + calc_func: Callable, + *, + rnd=RoundingModeKind.RZ, + ftz=True, + loc=None, + ip=None, +) -> Tuple[Float32, Float32]: + vec_type = ir.VectorType.get([2], Float32.mlir_type, loc=loc) + vec_src_a = vector.from_elements( + vec_type, tuple(as_numeric(a).ir_value() for a in src_a), loc=loc, ip=ip + ) + vec_src_b = vector.from_elements( + vec_type, tuple(as_numeric(b).ir_value() for b in src_b), loc=loc, ip=ip + ) + if src_c is not None: + vec_src_c = vector.from_elements( + vec_type, tuple(as_numeric(c).ir_value() for c in src_c), loc=loc, ip=ip + ) + vec_res = calc_func( + vec_type, vec_src_a, vec_src_b, vec_src_c, rnd=rnd, ftz=ftz, loc=loc, ip=ip + ) + else: + vec_res = calc_func( + vec_type, vec_src_a, vec_src_b, rnd=rnd, ftz=ftz, loc=loc, ip=ip + ) + + res0 = Float32( + vector.extract( + vec_res, dynamic_position=[], static_position=[0], loc=loc, ip=ip + ) + ) + res1 = Float32( + vector.extract( + vec_res, dynamic_position=[], static_position=[1], loc=loc, ip=ip + ) + ) + return res0, res1 + + +fma_packed_f32x2 = partial(calc_packed_f32x2_op, calc_func=nvvm.fma_packed_f32x2) +mul_packed_f32x2 = partial( + calc_packed_f32x2_op, src_c=None, 
calc_func=nvvm.mul_packed_f32x2 +) +add_packed_f32x2 = partial( + calc_packed_f32x2_op, src_c=None, calc_func=nvvm.add_packed_f32x2 +) + + +@dsl_user_op +def fmax( + a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None +) -> Float32: + return Float32( + nvvm.fmax( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + + +@dsl_user_op +def rcp_approx(a: Union[float, Float32], *, loc=None, ip=None): + return Float32( + nvvm.rcp_approx_ftz_f( + T.f32(), Float32(a).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + ) + + +@dsl_user_op +@deprecated( + "cute.arch.exp2 is deprecated, use cute.math.exp2 with `fastmath=True` instead" +) +def exp2(a: Union[float, Float32], *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "ex2.approx.ftz.f32 $0, $1;", + "=f,f", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +@deprecated( + "cute.arch.exp is deprecated, use cute.math.exp with `fastmath=True` instead" +) +def exp(a: Union[float, Float32], *, loc=None, ip=None) -> Float32: + LOG2_E = 1.4426950408889634 + return exp2(a * LOG2_E, loc=loc, ip=ip) + + +@dsl_user_op +@deprecated( + "cute.arch.exp_packed_f32x2 is deprecated, use cute.arch.mul_packed_f32x2 and cute.math.exp2 with `fastmath=True` instead" +) +def exp_packed_f32x2( + a: Tuple[Float32, Float32], *, loc=None, ip=None +) -> Tuple[Float32, Float32]: + LOG2_E = Float32(1.4426950408889634) + b = mul_packed_f32x2(a, (LOG2_E, LOG2_E), loc=loc, ip=ip) + return exp2(b[0], loc=loc, ip=ip), exp2(b[1], loc=loc, ip=ip) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/smem.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/smem.py new file mode 100644 index 
0000000000000000000000000000000000000000..37f87ea64d7f7482f3b2f464be6a0ee1a2e3494f --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/smem.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Optional, Type + +from cutlass.cutlass_dsl import T, dsl_user_op + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir import ir + +from ..typing import Pointer, Numeric, NumericMeta + + +@dsl_user_op +def alloc_smem( + element_type: Type[Numeric], + size_in_elems: int, + alignment: Optional[int] = None, + *, + loc=None, + ip=None, +) -> Pointer: + """ + Statically allocates SMEM. + + :param element_type: The pointee type of the pointer. 
+ :type element_type: Type[Numeric] + :param size_in_elems: The size of the allocation in terms of number of elements of the + pointee type + :type size_in_elems: int + :param alignment: An optional pointer alignment for the allocation + :type alignment: int + :return: A pointer to the start of the allocation + :rtype: Pointer + """ + if not isinstance(element_type, NumericMeta): + raise TypeError( + f"element_type must be a type of Numeric, but got {element_type}" + ) + + if alignment is None: + # Default alignment based on the element type's width + alignment = element_type.width // 8 + ptr_ty = _cute_ir.PtrType.get( + element_type.mlir_type, _cute_ir.AddressSpace.smem, alignment + ) + return _cute_nvgpu_ir.arch_alloc_smem( + ptr=ptr_ty, + input=ir.IntegerAttr.get(T.i32(), size_in_elems), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def get_dyn_smem( + element_type: Type[Numeric], + alignment: Optional[int] = None, + *, + loc=None, + ip=None, +) -> Pointer: + """ + Retrieves a pointer to a dynamic SMEM allocation. + + :param element_type: The pointee type of the pointer. + :type element_type: Type[Numeric] + :param alignment: An optional pointer alignment, the result pointer is offset appropriately + :type alignment: int + :return: A pointer to the start of the dynamic SMEM allocation with a correct + alignement + :rtype: Pointer + """ + if not isinstance(element_type, NumericMeta): + raise TypeError( + f"element_type must be a type of Numeric, but got {element_type}" + ) + + if alignment is None: + # Default alignment based on the element type's width + alignment = element_type.width // 8 + ptr_ty = _cute_ir.PtrType.get( + element_type.mlir_type, + _cute_ir.AddressSpace.smem, + alignment, + ) + return _cute_nvgpu_ir.arch_get_dyn_smem(ptr=ptr_ty, loc=loc, ip=ip) + + +@dsl_user_op +def get_dyn_smem_size(*, loc=None, ip=None) -> int: + """ + Gets the size in bytes of the dynamic shared memory that was specified at kernel launch time. 
+ This can be used for bounds checking during shared memory allocation. + + :return: The size of dynamic shared memory in bytes + :rtype: int + """ + return _cute_nvgpu_ir.arch_get_dyn_smem_size(loc=loc, ip=ip) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/tmem.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/tmem.py new file mode 100644 index 0000000000000000000000000000000000000000..302616d20b34ccfe1d3194e48bf94114eeafeaec --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/arch/tmem.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Type + +from cutlass.cutlass_dsl import dsl_user_op + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir + +from ..typing import Pointer, Int, Int32, Numeric, NumericMeta + + +SM100_TMEM_CAPACITY_COLUMNS = 512 +SM100_TMEM_MIN_ALLOC_COLUMNS = 32 + + +@dsl_user_op +def retrieve_tmem_ptr( + element_type: Type[Numeric], + alignment: int, + ptr_to_buffer_holding_addr: Pointer, + *, + loc=None, + ip=None, +) -> Pointer: + """ + Retrieves a pointer to TMEM with the provided element type and alignment. + + :param element_type: The pointee type of the pointer. 
+ :type element_type: Type[Numeric] + :param alignment: The alignment of the result pointer + :type alignment: int + :param ptr_to_buffer_holding_addr: A pointer to a SMEM buffer holding the TMEM address of the + start of the allocation allocation + :type ptr_to_buffer_holding_addr: Pointer + :return: A pointer to TMEM + :rtype: Pointer + """ + if not isinstance(element_type, NumericMeta): + raise TypeError( + f"element_type must be a type of Numeric, but got {element_type}" + ) + + res_ty = _cute_ir.PtrType.get( + element_type.mlir_type, _cute_ir.AddressSpace.tmem, alignment + ) + return _cute_nvgpu_ir.arch_sm100_retrieve_tmem_ptr( + res_ty, ptr_to_buffer_holding_addr.value, loc=loc, ip=ip + ) + + +@dsl_user_op +def alloc_tmem( + num_columns: Int, + smem_ptr_to_write_address: Pointer, + is_two_cta=None, + *, + loc=None, + ip=None, +) -> None: + """ + Allocates TMEM. + + :param num_columns: The number of TMEM columns to allocate + :type num_columns: Int + :param smem_ptr_to_write_address: A pointer to a SMEM buffer where the TMEM address is written + to + :type smem_ptr_to_write_address: Pointer + :param is_two_cta: Optional boolean parameter for 2-CTA MMAs + """ + if isinstance(num_columns, int): + if ( + num_columns < SM100_TMEM_MIN_ALLOC_COLUMNS + or num_columns > SM100_TMEM_CAPACITY_COLUMNS + or not (num_columns & (num_columns - 1) == 0) + ): + raise ValueError( + f"num_columns must be between 32 and 512, and must be pow of 2, but got {num_columns}" + ) + _cute_nvgpu_ir.arch_sm100_alloc_tmem( + Int32(num_columns).ir_value(loc=loc, ip=ip), + smem_ptr_to_write_address.value, + is_two_cta=is_two_cta, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def relinquish_tmem_alloc_permit(is_two_cta=None, *, loc=None, ip=None) -> None: + """ + Relinquishes the right to allocate TMEM so that other CTAs potentially in a different grid can + allocate. 
+ """ + _cute_nvgpu_ir.arch_sm100_relinquish_tmem_alloc_permit( + is_two_cta=is_two_cta, loc=loc, ip=ip + ) + + +@dsl_user_op +def dealloc_tmem( + tmem_ptr: Pointer, + num_columns: Int, + is_two_cta=None, + *, + loc=None, + ip=None, +) -> None: + """ + Deallocates TMEM using the provided pointer and number of columns. + + :param tmem_ptr: A pointer to the TMEM allocation to de-allocate + :type tmem_ptr: Pointer + :param num_columns: The number of columns in the TMEM allocation + :type num_columns: Int + :param is_two_cta: Optional boolean parameter for 2-CTA MMAs + """ + if isinstance(num_columns, int): + if ( + num_columns < SM100_TMEM_MIN_ALLOC_COLUMNS + or num_columns > SM100_TMEM_CAPACITY_COLUMNS + or not (num_columns & (num_columns - 1) == 0) + ): + raise ValueError( + f"num_columns must be between 32 and 512, and must be pow of 2, but got {num_columns}" + ) + _cute_nvgpu_ir.arch_sm100_dealloc_tmem( + tmem_ptr.value, + Int32(num_columns).ir_value(loc=loc, ip=ip), + is_two_cta=is_two_cta, + loc=loc, + ip=ip, + ) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/core.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/core.py new file mode 100644 index 0000000000000000000000000000000000000000..12d5e4221a3e6007656a9400966e84d8b9a25a79 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/core.py @@ -0,0 +1,7070 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import copy as py_copy +from dataclasses import dataclass +import inspect +import math +import operator +from abc import ABC, abstractmethod +from functools import lru_cache, partial, reduce +from inspect import isclass +from itertools import chain +from typing import ( + Callable, + Iterable, + overload, + List, + Tuple, + Union, + Type, + Any, + Dict, + Optional, +) +from enum import Enum, auto + +from cutlass.cutlass_dsl import ( + const, + T, + lru_cache_ir, + is_dynamic_expression, + for_generate, + yield_out, + if_generate, + extract_mlir_values, + new_from_mlir_values, + _binary_op_type_promote, + not_, + cutlass_arith, + dsl_user_op, +) + +from cutlass._mlir import ir +from cutlass._mlir.dialects._ods_common import get_op_result_or_op_results +from cutlass._mlir.dialects import cute as _cute_ir +from cutlass._mlir.dialects.cute import ( + ScaledBasis as _ScaledBasis, + Ratio as _Ratio, +) + +from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir.dialects import llvm, builtin, vector, arith + +from .typing import ( + Numeric, + Integer, + NumericMeta, + Boolean, + Int32, + Int8, + Int16, + Int32, + Int64, + Float32, + TFloat32, + Int, + IntTuple, + Shape, + Stride, + Coord, + Layout, + Tile, + Tiler, + XTuple, + Tensor, + Pointer, + AddressSpace, + as_numeric, +) + + +#################################################################################################### +# +# Internal IntTuple helpers +# +#################################################################################################### + + 
+def _get_typed_value(x): + if isinstance(x, Integer): + return ( + x.value.get_typed_value() if isinstance(x.value, IntValue) else x.ir_value() + ) + else: + return x + + +def _pack_x(x, packer, op, *, loc=None, ip=None) -> ir.Value: + x = transform_leaf(_get_typed_value, x) + res_ty, dyn_elems = packer(x) + # <"0"> is deduced from type inference which should be removed for make_... operations + dyn_elems = [t for t in dyn_elems if not is_static(t)] + return op(res_ty, dyn_elems, loc=loc, ip=ip).result + + +def _pack_shape(shape: Shape, *, loc=None, ip=None) -> ir.Value: + _check_shape(shape) + return _pack_x(shape, _cute_ir.pack_shape, _cute_ir.MakeShapeOp, loc=loc, ip=ip) + + +def _pack_stride(stride: Stride, *, loc=None, ip=None) -> ir.Value: + _check_stride(stride) + # Convert basis elements to the base class before _pack_x + stride = transform_leaf( + lambda x: x.to(_cute_ir.ScaledBasis) if isinstance(x, ScaledBasis) else x, + stride, + ) + return _pack_x(stride, _cute_ir.pack_stride, _cute_ir.MakeStrideOp, loc=loc, ip=ip) + + +def _pack_coord(coord: Coord, *, loc=None, ip=None) -> ir.Value: + _check_coord(coord) + return _pack_x(coord, _cute_ir.pack_coord, _cute_ir.MakeCoordOp, loc=loc, ip=ip) + + +def _pack_int_tuple(int_tuple: IntTuple, *, loc=None, ip=None) -> ir.Value: + _check_int_tuple(int_tuple) + return _pack_x( + int_tuple, _cute_ir.pack_int_tuple, _cute_ir.MakeIntTupleOp, loc=loc, ip=ip + ) + + +def _pack_tile(tile: Tile, *, loc=None, ip=None) -> ir.Value: + _check_tile(tile) + + def expand_leaves(tile) -> list: + leaves = [] + for e in tile: + if isinstance(e, _Layout): + leaves.extend(list(flatten_to_tuple(e.shape))) + leaves.extend(list(flatten_to_tuple(e.stride))) + else: + leaves.append(e) + return leaves + + layout_leaves = flatten_to_tuple(tile) + dyn_elems = expand_leaves(layout_leaves) + dyn_elems = [ + _get_typed_value(x) for x in dyn_elems if isinstance(x, (Integer, ir.Value)) + ] + + res_ty = _cute_ir.pack_tile(tile) + return 
_cute_ir.make_tile(res_ty, dyn_elems, loc=loc, ip=ip) + + +def _unpack_x_tuple(t: Union[ir.Type, ir.Value], *, loc=None, ip=None) -> XTuple: + # If t is an MLIR type, make sure it's static and make a Value + if isinstance(t, ir.Type): + if not _cute_ir.is_static(t): + raise ValueError() + t = _cute_ir.static(t) + + if isinstance(t, ir.Value): + input_ty = t.type + if t.type.rank == 0: + # Handle this case separately, _cute_ir.get_leaves will return an Op in this case + vals = [] + else: + vals = _cute_ir.get_leaves(t, loc=loc, ip=ip) + if not isinstance(vals, list): + vals = [vals] + else: + raise TypeError(f"expects static type or value, but got {t}") + + # CuTe IR only supports Int32 for now. Need to support detection of other types + res = _cute_ir.unpack_x_tuple(input_ty, vals) + + def post_process(x): + if isinstance(x, _cute_ir.ScaledBasis): + return ScaledBasis(post_process(x.get_value()), x.get_mode()) + elif isinstance(x, _cute_ir.Ratio): + return Ratio(x.numerator, x.denominator) + else: + return x + + return transform_leaf(post_process, res) + + +#################################################################################################### +# Validation helpers +#################################################################################################### + + +def _check_shape(shape: Shape) -> None: + if is_integer(shape): + if isinstance(shape, int): + if shape <= 0: + raise ValueError( + f"Expected size in shape to be strictly positive, but got {shape}" + ) + elif isinstance(shape, Integer): + pass + else: + raise TypeError(f"Expected size be int or Integer, but got {type(shape)}") + elif isinstance(shape, tuple): + for s in shape: + _check_shape(s) + else: + raise ValueError( + f"Expected Shape, which is a positive integer or tuple of Shapes, but got {shape}" + ) + + +def _check_coord(coord: Coord) -> None: + flat_coord = flatten_to_tuple(coord) + if not all(is_integer(c) or c is None for c in flat_coord): + raise ValueError( + f"Expected 
Coord, whose leaves are integers or None, but got {coord}" + ) + + +def _check_stride(stride: Stride) -> None: + flat_stride = flatten_to_tuple(stride) + if not all(is_integer(s) or isinstance(s, ScaledBasis) for s in flat_stride): + raise ValueError( + f"Expected Stride, whose leaves are integers or ScaledBasis, but got {stride}" + ) + + +def _check_int_tuple(int_tuple: IntTuple) -> None: + flat_int_tuple = flatten_to_tuple(int_tuple) + if not all(is_integer(d) for d in flat_int_tuple): + raise ValueError( + f"Expected IntTuple, whose leaves are integers, but got {int_tuple}" + ) + + +def _check_tile(tile: Tile) -> None: + flat_tile = flatten_to_tuple(tile) + if not all(is_integer(t) or isinstance(t, _Layout) or t is None for t in flat_tile): + raise ValueError( + f"Expected Tile, whose leaves are integers or Layout or None, but got {tile}" + ) + + +#################################################################################################### +# +# Core types +# +#################################################################################################### + + +class IntValue(cutlass_arith.ArithValue): + """Internal representation of constrained integer types with divisibility information. + + IntValue serves as a proxy for constrained integer types in the CuTe IR. Rather than + directly storing values of IntTupleType with depth=0, it stores the result of the + `cute.get_scalars` operation applied to such values. + + This class represents the following sequence of operations in the IR: + %0 = ... : (...) -> !cute.int_tuple<"?"> + %1 = cute.get_scalars(%0) : (!cute.int_tuple<"?">) -> i32 + + where the first operation produces a `cute.int_tuple<"?">` with depth=0 and rank=1. It + automatically emit `cute.get_scalars` and track it. 
+ + IntValue inherits behavior from ArithValue with the following extensions: + * Overloaded operations that accept IntTupleType values to propagate divisibility information + * Support for CuTe operations that utilize divisibility constraints + + API for interacting with IntValue: + * get_typed_value() - Returns the value as an IntTupleType + * get_divisibility() - Returns the divisibility constraint of the value + """ + + def __init__(self, v, signed=True): + # Cute Constrained Int Type is always signed + if isinstance(v, int): + v = _pack_int_tuple(v) + + if isinstance(v.type, _cute_ir.IntTupleType): + scalar_val = _cute_ir.get_scalars(v) + super().__init__(scalar_val, True) + else: + super().__init__(v, True) + + def get_typed_value(self): + if isinstance(self.type, ir.IntegerType): + def_op = self.owner.operation + if def_op.name == "cute.get_scalars": + return def_op.operands[0] + + assert not isinstance(self.type, _cute_ir.IntTupleType) + + return _pack_int_tuple(self) + + @property + def divisibility(self): + if isinstance(self.get_typed_value().type, _cute_ir.IntTupleType): + return self.get_typed_value().type.get_divisibility([0]) + else: + return 1 + + def __str__(self): + if self.divisibility == 1: + return f"?" 
+ else: + return f"?{{div={self.divisibility}}}" + + def __repr__(self): + parent_name = cutlass_arith.ArithValue.__name__ + return super().__str__().replace(parent_name, IntValue.__name__) + + def pretty_str(self): + return self.__str__() + + @staticmethod + def _binary_op(op): + def wrapper(self, other, **kwargs): + if isinstance(other, IntValue): + other_val = other.get_typed_value() + elif isinstance(other, ir.Value) and isinstance( + other.type, _cute_ir.IntTupleType + ): + other_val = other + elif isinstance(other, ir.Value) and isinstance(other.type, ir.IntegerType): + other = cutlass_arith.int_to_int(other, Int32, **kwargs) + other_val = _pack_int_tuple(other) + elif isinstance(other, (int, bool)): + other_val = _pack_int_tuple(int(other)) + else: + # Dispatch to `__rmul__` of `other` + return NotImplemented + + return IntValue(op(self, other_val, **kwargs)) + + return wrapper + + @dsl_user_op + @_binary_op + def __add__(self, other, *, loc=None, ip=None): + return _cute_ir.add_offset(self.get_typed_value(), other, loc=loc, ip=ip) + + @dsl_user_op + @_binary_op + def __sub__(self, other, *, loc=None, ip=None): + return _cute_ir.tuple_sub(self.get_typed_value(), other, loc=loc, ip=ip) + + @dsl_user_op + @_binary_op + def __mul__(self, other, *, loc=None, ip=None): + return _cute_ir.tuple_mul(self.get_typed_value(), other, loc=loc, ip=ip) + + @dsl_user_op + @_binary_op + def __floordiv__(self, other, *, loc=None, ip=None) -> "IntValue": + return _cute_ir.tuple_div(self.get_typed_value(), other, loc=loc, ip=ip) + + @dsl_user_op + @_binary_op + def __mod__(self, other, *, loc=None, ip=None) -> cutlass_arith.ArithValue: + return _cute_ir.tuple_mod(self.get_typed_value(), other, loc=loc, ip=ip) + + @dsl_user_op + @_binary_op + def __radd__(self, other, *, loc=None, ip=None) -> "IntValue": + return _cute_ir.add_offset(other, self.get_typed_value(), loc=loc, ip=ip) + + @dsl_user_op + @_binary_op + def __rsub__(self, other, *, loc=None, ip=None) -> "IntValue": + 
return _cute_ir.tuple_sub(other, self.get_typed_value(), loc=loc, ip=ip) + + @dsl_user_op + @_binary_op + def __rmul__(self, other, *, loc=None, ip=None): + return _cute_ir.tuple_mul(other, self.get_typed_value(), loc=loc, ip=ip) + + @dsl_user_op + @_binary_op + def __rfloordiv__(self, other, *, loc=None, ip=None) -> "IntValue": + return _cute_ir.tuple_div(other, self.get_typed_value(), loc=loc, ip=ip) + + @dsl_user_op + @_binary_op + def __rmod__(self, other, *, loc=None, ip=None) -> "IntValue": + return _cute_ir.tuple_mod(other, self.get_typed_value(), loc=loc, ip=ip) + + +class Ratio(_Ratio): + """A class representing a rational number as a ratio of two integers. + + Ratio is used in CuTe to represent exact fractional values that arise in + tensor layout operations, particularly in composition operations where + divisibility conditions may not be satisfied. + + :param numerator: The numerator of the ratio + :type numerator: int + :param denominator: The denominator of the ratio + :type denominator: int + :raises TypeError: If numerator or denominator are not integers + """ + + def __init__(self, numerator: int, denominator: int): + if not isinstance(numerator, int) or not isinstance(denominator, int): + raise TypeError( + f"numerator and denominator must be integers, but got {numerator} and {denominator}" + ) + super().__init__(numerator, denominator) + + def is_integral(self) -> bool: + """Check if the ratio represents an integer value. + + :return: True if the numerator is divisible by the denominator + :rtype: bool + """ + return super().is_integral() + + def reduced(self) -> "Ratio": + """Return a new Ratio with the numerator and denominator reduced to lowest terms. + + :return: A new Ratio in reduced form + :rtype: Ratio + """ + res = super().reduced() + return Ratio(res.numerator, res.denominator) + + def __mul__(self, other): + """Multiply this ratio by another ratio or an integer. 
+ + :param other: The value to multiply by + :type other: Union[Ratio, int] + :return: A new ratio representing the product + :rtype: Ratio + :raises TypeError: If other is not a Ratio or int + """ + if isinstance(other, Ratio): + return Ratio( + self.numerator * other.numerator, + self.denominator * other.denominator, + ) + elif isinstance(other, int): + return Ratio(self.numerator * other, self.denominator) + else: + raise TypeError(f"Cannot multiply Ratio with {type(other)}") + + def __rmul__(self, other): + """Right multiplication operation. + + :param other: The value to multiply by + :type other: Union[Ratio, int] + :return: A new ratio representing the product + :rtype: Ratio + """ + return self.__mul__(other) + + def __str__(self): + """String representation of the ratio. + + :return: String in the format "numerator/denominator" + :rtype: str + """ + return super().__str__() + + def to(self, dtype): + """Convert the ratio to another type. + + :param dtype: The target type for conversion + :type dtype: type + :return: The ratio converted to the specified type + :raises TypeError: If conversion to the specified type is not supported + """ + if dtype is Ratio: + return self + elif dtype is float: + return self.numerator / self.denominator + elif dtype is int: + return self.numerator // self.denominator + elif issubclass(dtype, _Ratio): + return self + else: + raise TypeError(f"Cannot convert Ratio to {dtype}") + + +class ScaledBasis: + """A class representing a scaled basis element in CuTe's layout algebra. + + ScaledBasis is used to represent elements in the layout algebra, particularly + in the context of composition operations. It consists of a value (scale) and + a mode that identifies mode of the basis element. 
+ + :param value: The scale value + :type value: Union[int, Integer, Ratio, ir.Value] + :param mode: The mode identifying the basis element + :type mode: Union[int, List[int]] + :raises TypeError: If mode is not an integer or list of integers + + **Examples:** + + .. code-block:: python + + # Create a scaled basis with integer scale and mode + sb1 = ScaledBasis(2, 0) # 2 * E(0) + + # Create a scaled basis with a Ratio scale + sb2 = ScaledBasis(Ratio(1, 2), 1) # (1/2) * E(1) + + # Create a scaled basis with a list of modes + sb3 = ScaledBasis(4, [0, 1]) # 4 * E([0, 1]) + + # Scaled basis elements are commonly used in layout strides + layout = make_layout((4, 8), stride=(ScaledBasis(2, 0), ScaledBasis(1, 1))) + + # This creates a layout with strides (2@0, 1@1) representing + # a coordinate system where each dimension has its own basis + + # Example: Mapping coordinates to indices using the layout + coord = (2, 3) + idx = crd2idx(coord, layout) # Maps (2, 3) to (4, 3) + """ + + def __init__(self, value, mode) -> None: + if isinstance(mode, int): + self._mode = [mode] + else: + if any(not isinstance(x, int) for x in mode): + raise TypeError("Mode must be a list of integers") + self._mode = mode + + self._value = value + + def is_static(self) -> bool: + """Check if the value is statically known. + + :return: True if the value is not a dynamic expression + :rtype: bool + """ + return not is_dynamic_expression(self._value) + + def to(self, dtype): + """Convert to another type. 
+ + :param dtype: The target type for conversion + :type dtype: type + :return: The ScaledBasis converted to the specified type + :raises TypeError: If conversion to the specified type is not supported + """ + if dtype is ScaledBasis: + return self + elif dtype is _ScaledBasis: + if isinstance(self._value, Ratio): + scale = self._value + elif isinstance(self._value, Integer): + scale = self._value.ir_value() + else: + scale = self._value + + if isinstance(scale, IntValue): + return _ScaledBasis(scale.get_typed_value(), self._mode) + else: + return _ScaledBasis(scale, self._mode) + else: + raise TypeError(f"Cannot convert ScaledBasis to {dtype}") + + def __str__(self): + return f"{self.to(_ScaledBasis).__str__()}" + + def __hash__(self): + if isinstance(self.mode, list): + return hash((self.value, tuple(self.mode))) + else: + return hash((self.value, self.mode)) + + @property + def value(self): + """Get the scale value. + + :return: The scale value + """ + return self._value + + @property + def mode(self) -> List[int]: + """Get the mode identifying the basis element. + + :return: The mode as a list of integers + :rtype: List[int] + """ + return self._mode + + def __eq__(self, other): + if isinstance(other, ScaledBasis): + return self.value == other.value and self.mode == other.mode + else: + return False + + def __rmul__(self, scale: Union[Int, ir.Value, Ratio]) -> "ScaledBasis": + """Right multiplication by a scale factor. + + This operation is used in layout algebra to scale basis elements, + which is essential for operations like composition and partitioning. 
+ + :param scale: The scale factor + :type scale: Union[Int, ir.Value, Ratio] + :return: A new scaled basis element + :rtype: ScaledBasis + :raises TypeError: If scale is not of a supported type + :raises NotImplementedError: If scaling a basis element with a ratio value + """ + if not isinstance(scale, (int, Integer, Ratio, ir.Value)): + raise TypeError( + f"scale must be an integer or a ratio, but got {type(scale)}" + ) + if isinstance(self.value, Ratio): + raise NotImplementedError( + "scaling a basis element having a ratio is not supported" + ) + + value = self.value + + if not isinstance(value, (Integer, Ratio, int, cutlass_arith.ArithValue)): + raise TypeError(f"Don't support {type(value)} for ScaledBasis") + + # Lift to IntValue type to preserve type info as much as possible + if isinstance(scale, cutlass_arith.ArithValue): + scale = IntValue(_pack_int_tuple(cutlass_arith.int_to_int(scale, Int32))) + + if isinstance(value, cutlass_arith.ArithValue): + value = IntValue(_pack_int_tuple(cutlass_arith.int_to_int(value, Int32))) + elif isinstance(value, Integer): + value = value.ir_value() + + return ScaledBasis(scale * value, self.mode) # type: ignore + + +def E(mode: Union[int, List[int]]) -> ScaledBasis: + """Create a unit ScaledBasis element with the specified mode. + + This function creates a ScaledBasis with value 1 and the given mode. + The mode represents the coordinate axis or dimension in the layout. + + :param mode: The mode (dimension) for the basis element, either a single integer or a list of integers + :type mode: Union[int, List[int]] + :return: A ScaledBasis with value 1 and the specified mode + :rtype: ScaledBasis + :raises TypeError: If mode is not an integer or a list + + **Examples:** + + .. 
code-block:: python + + # Create a basis element for the first dimension (mode 0) + e0 = E(0) + + # Create a basis element for the second dimension (mode 1) + e1 = E(1) + + # Create a basis element for a hierarchical dimension + e_hier = E([0, 1]) + """ + if isinstance(mode, int): + mode = [mode] + + if not isinstance(mode, list): + raise TypeError(f"expects a list, got {type(mode)}") + + if not mode: + return 1 + + return ScaledBasis(1, mode) + + +def get_divisibility(x: Union[int, Integer]) -> int: + if isinstance(x, int): + return x + + if isinstance(x, Integer): + x = x.value + + if isinstance(x, IntValue): + return x.divisibility + else: + return 1 + + +@ir.register_value_caster(_cute_ir.SwizzleType.get_static_typeid(), replace=True) +class Swizzle(ir.Value): + """ + Swizzle is a transformation that permutes the elements of a layout. + + Swizzles are used to rearrange data elements to improve memory access patterns + and computational efficiency. + + Swizzle is defined by three parameters: + - MBase: The number of least-significant bits to keep constant + - BBits: The number of bits in the mask + - SShift: The distance to shift the mask + + The mask is applied to the least-significant bits of the layout. + + .. code-block:: + + 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx + ^--^ MBase is the number of least-sig bits to keep constant + ^-^ ^-^ BBits is the number of bits in the mask + ^---------^ SShift is the distance to shift the YYY mask + (pos shifts YYY to the right, neg shifts YYY to the left) + + e.g. Given + 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx + + the result is + 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ `xor` YY + + """ + + def __str__(self): + # Cut off the MLIR type's string for making pretty_str more concise + return self.type.__str__()[15 : 15 + 8] + + +@ir.register_value_caster(_cute_ir.LayoutType.get_static_typeid(), replace=True) +class _Layout(Layout): + """Layout is CuTe's core abstraction for representing tensor layouts. 
+ + A Layout maps from a logical coordinate space to an index space, defined by a + pair of (Shape, Stride). The Shape defines the abstract dimensions of the Layout, + while the Stride defines how coordinates within the Shape map to linear indices. + + Layouts present a common interface to multidimensional array access that abstracts + away the details of how array elements are organized in memory. This allows algorithms + to be written generically, so that layouts can change without requiring code changes. + + CuTe layouts are inherently hierarchical, constructed from smaller, nested layouts + that can represent complex mappings required by GPU tensor instructions. They support + a rich algebra of operations including concatenation, coalescence, composition, + complement, and inversion. + + :ivar shape: An IntTuple representing the dimensions of the layout. + :ivar stride: An IntTuple representing the strides of the layout. + :ivar max_alignment: The maximum alignment of the layout. + + **Examples:** + + .. code-block:: python + + # Creating a layout with shape (4,8) and default stride (layout left / "column major") + layout = cute.make_layout((4, 8)) + + # Creating a layout with explicit shape and stride + layout = cute.make_layout((4, 8), stride=(8, 1)) + + # Accessing a specific coordinate: (2, 3) -> 2 * 8 + 3 * 1 = 19 + idx = cute.crd2idx((2, 3), layout) + """ + + def __init__(self, op_result) -> None: + """Initialize a Layout object. + + :param op_result: The operation result value to wrap. + """ + super().__init__(op_result) + + def __str__(self) -> str: + """Return a string representation of the layout. + + :return: A string in the format "shape:stride". + """ + return f"{pretty_str(self.shape)}:{pretty_str(self.stride)}" + + @property + def shape(self, *, loc=None, ip=None) -> Shape: + """Get the shape of the layout. + + The shape defines the dimensions and structure of the layout's + coordinate space. 
+ + :param loc: Optional location information for debugging. + :param ip: Optional insertion point for IR generation. + :return: The hierarchical shape of the layout. + """ + return _unpack_x_tuple(_cute_ir.get_shape(self, loc=loc, ip=ip), loc=loc, ip=ip) + + @property + def stride(self, *, loc=None, ip=None) -> Stride: + """Get the stride of the layout. + + The stride defines how coordinates map to linear indices in memory. + + :param loc: Optional location information for debugging. + :param ip: Optional insertion point for IR generation. + :return: The hierarchical stride of the layout. + """ + return _unpack_x_tuple( + _cute_ir.get_stride(self, loc=loc, ip=ip), loc=loc, ip=ip + ) + + @property + def max_alignment(self) -> int: + """Get the maximum alignment of the layout. + + :return: The maximum alignment in bytes. + """ + return self.type.max_alignment + + def __eq__(self, other) -> Union[bool, Boolean]: + """Check if this layout is equal to another layout. + + Two layouts are equal if they have the same shape and stride. + + :param other: The layout to compare with. + :return: True if layouts are equal, False otherwise. + May return an IR value for dynamic layouts. + """ + if isinstance(other, Layout): + if is_static(self.type) and is_static(other.type): + return self.type == other.type + return Boolean(_cute_ir.equal(self, other)) + else: + return False + + def __req__(self, other) -> Union[bool, Boolean]: + """Reflected equality check. + + :param other: The layout to compare with. + :return: Result of other.__eq__(self). + """ + if isinstance(other, Layout): + return other.__eq__(self) + return False + + def __ne__(self, other) -> Union[bool, Boolean]: + """Check if this layout is not equal to another layout. + + :param other: The layout to compare with. + :return: True if layouts are not equal, False otherwise. 
+ """ + if isinstance(other, Layout): + if is_static(self.type) and is_static(other.type): + return self.type != other.type + return Boolean(not_(_cute_ir.equal(self, other))) + else: + return True + + def __rne__(self, other) -> Union[bool, Boolean]: + """Reflected inequality check. + + :param other: The layout to compare with. + :return: Result of other.__ne__(self). + """ + if isinstance(other, Layout): + return other.__ne__(self) + return False + + def __getitem__(self, idx: int) -> Layout: + """ + Top-level `get` to provide a syntax similar to `tuple`. + """ + return get(self, mode=[idx]) + + @dsl_user_op + def __call__(self, coord: Coord, loc=None, ip=None) -> IntTuple: + return crd2idx(coord, self, loc=loc, ip=ip) + + @dsl_user_op + def get_hier_coord(self, idx, *, loc=None, ip=None) -> Coord: + """Get the hierarchical coordinate corresponding to a linear index. + + This method maps from a linear index back to the logical coordinate + in the layout's coordinate space. + + :param idx: The linear index to convert. + :return: The hierarchical coordinate corresponding to the index. + + **Examples:** + + .. code-block:: python + + layout = make_layout((4, 8), stride=(8, 1)) + + # map linear index back to coordinate: 5 -> (1, 1) + coord = get_hier_coord(5, layout) + """ + idx_val = Int32(idx).ir_value() + crd = _cute_ir.get_hier_coord(idx_val, self, loc=loc, ip=ip) + return _unpack_x_tuple(crd) + + @dsl_user_op + def get_flat_coord(self, idx, *, loc=None, ip=None) -> Coord: + idx_val = Int32(idx).ir_value() + res = _cute_ir.get_flat_coord(idx_val, self, loc=loc, ip=ip) + return _unpack_x_tuple(res, loc=loc, ip=ip) + + +@ir.register_value_caster(_cute_ir.ComposedLayoutType.get_static_typeid(), replace=True) +class ComposedLayout(ir.Value): + r"""ComposedLayout represents the functional composition of layouts in CuTe. 
+ + A ComposedLayout is formed by the composition of three components: + inner o offset o outer, where: + + - inner: The inner layout or swizzle that is applied last + - offset: An integer tuple representing a coordinate offset + - outer: The outer layout that is applied first + + ComposedLayout implements the functional composition operation where: + + .. math:: + + R(c) := (inner \\circ offset \\circ outer)(c) := inner(offset + outer(c)) + + This composition allows for complex transformations of coordinates and indices, + enabling operations like tiling, partitioning, and reshaping of data. + + :ivar inner: The inner layout or swizzle component + :ivar offset: The coordinate offset applied between inner and outer layouts + :ivar outer: The outer layout component + :ivar max_alignment: The maximum alignment of the composed layout + + **Examples:** + + .. code-block:: python + + # Create a composed layout with inner layout, offset, and outer layout + + # inner layout: (4, 8):(1, 4) + inner_layout = make_layout((4, 8)) + + offset = (0, 0) + + # outer layout: (2, 2):(1@0, 1@1) + outer_layout = make_layout((2, 2), stride=(1 * E(0), 1 * E(1))) + + # composed layout: (inner o offset o outer) + composed = make_composed_layout(inner_layout, offset, outer_layout) + + # Accessing components of the composed layout + inner = composed.inner + offset = composed.offset + outer = composed.outer + + # map coordinate (0, 1) to linear index + # - outer(0, 1) = (0, 1) + # - offset + outer(0, 1) = (0, 1) + # - inner(0, 1) = 0 * 1 + 1 * 4 = 4 + idx = crd2idx((0, 1), composed) + + # Composition is used in many tiling operations + # For example, in logical_product, raked_product, and blocked_product + """ + + def __init__(self, value) -> None: + """Initialize a ComposedLayout object. + + :param value: The operation result value to wrap. 
+ """ + super().__init__(value) + + def __str__(self) -> str: + return f"{pretty_str(self.inner)} o {pretty_str(self.offset)} o {pretty_str(self.outer)}" + + @property + def inner(self, *, loc=None, ip=None) -> Union[Swizzle, Layout]: + return _cute_ir.composed_get_inner(self, loc=loc, ip=ip) + + @property + def offset(self, *, loc=None, ip=None) -> IntTuple: + return _unpack_x_tuple(_cute_ir.composed_get_offset(self, loc=loc, ip=ip)) + + @property + def outer(self, *, loc=None, ip=None) -> Layout: + return _cute_ir.composed_get_outer(self, loc=loc, ip=ip) + + @property + def shape(self, *, loc=None, ip=None) -> Shape: + return _unpack_x_tuple(_cute_ir.get_shape(self, loc=loc, ip=ip), loc=loc, ip=ip) + + @property + def max_alignment(self) -> int: + return self.type.max_alignment + + def __eq__(self, other) -> Union[bool, Boolean]: + if isinstance(other, ComposedLayout): + if is_static(self.type) and is_static(other.type): + return self.type == other.type + else: + raise NotImplementedError( + f"runtime comparison of composed layouts is not supported, got `{self}` and `{other}`" + ) + else: + return False + + def __req__(self, other) -> Union[bool, Boolean]: + if isinstance(other, ComposedLayout): + return Boolean(other.__eq__(self)) + return False + + def __ne__(self, other) -> Union[bool, Boolean]: + return not self.__eq__(other) + + def __rne__(self, other) -> Union[bool, Boolean]: + if isinstance(other, ComposedLayout): + return other.__ne__(self) + return False + + def __getitem__(self, idx: int) -> "ComposedLayout": + """ + Top-level `get` to provide a syntax similar to `tuple`. + """ + return get(self, mode=[idx]) + + @dsl_user_op + def __call__(self, coord: Coord, loc=None, ip=None) -> IntTuple: + return crd2idx(coord, self, loc=loc, ip=ip) + + +@ir.register_value_caster(_cute_ir.PtrType.get_static_typeid(), replace=True) +class _Pointer(Pointer): + """ + A pointer class representing a memory address with specific properties. 
+ + Pointers are a fundamental type of iterator/engine that support random-access operations. + They can be offset by elements of a layout's codomain and dereferenced to produce values. + + :param value: The MLIR operation result value to initialize the pointer with + :type value: ir.Value + + :ivar type: The MLIR type of the pointer + :vartype type: Type + :ivar value_type: The type of value this pointer points to + :vartype value_type: Type + :ivar memspace: The memory space where the pointer data resides (e.g., gmem, smem, rmem) + :vartype memspace: AddressSpace + + :note: When composed with a layout, a pointer forms a tensor: T = E ∘ L, where E is the pointer + and L is the layout. The tensor evaluates the layout by mapping a coordinate c to the + codomain, offsets the pointer accordingly, and dereferences the result: + T(c) = (E ∘ L)(c) = *(E + L(c)) + """ + + def __init__(self, value) -> None: + assert isinstance(value, ir.Value) + self.value = ir.Value(value) + + def __str__(self) -> str: + # Cut off the MLIR type's string for making pretty_str more concise + return self.type.__str__()[6:] + + def __get_mlir_types__(self): + return [self.value.type] + + def __extract_mlir_values__(self): + return [self.value] + + def __new_from_mlir_values__(self, values): + # Only expecting single value of _Pointer instance or ir.Value + # In this context, a _Pointer instance is an encapsulated ir.Value which is automatically created + # by value caster for cute.ptr typed values + assert len(values) == 1, f"Expected 1 value, but got {len(values)}" + assert isinstance( + values[0], (_Pointer, ir.Value) + ), f"Expected _Pointer or ir.Value, but got {type(values[0])}" + return _Pointer( + values[0] if isinstance(values[0], ir.Value) else values[0].value + ) + + @property + @lru_cache_ir() + def dtype(self) -> Type[Numeric]: + return Numeric.from_mlir_type(self.value.type.value_type) + + @property + def alignment(self) -> int: + return self.type.alignment + + @property + def 
max_alignment(self) -> int: + return self.type.max_alignment + + @property + @lru_cache_ir() + def memspace(self) -> AddressSpace: + return AddressSpace(self.type.address_space) + + # Make it behave as if it inherited from ir.Value + @property + @lru_cache_ir() + def type(self) -> ir.Type: + return self.value.type + + # Only use if you absolutely need to get the LLVM pointer Value + @property + @lru_cache_ir() + def llvm_ptr(self, *, loc=None, ip=None) -> ir.Value: + """ + Get the LLVM pointer representation of this pointer. + + :param loc: The source location for the operation, defaults to None + :type loc: Location, optional + :param ip: The insertion point for the operation, defaults to None + :type ip: InsertionPoint, optional + :return: The LLVM pointer representation + :rtype: ir.Value + """ + llvm_ptr_ty = llvm.PointerType.get(self.memspace.value) + return builtin.unrealized_conversion_cast( + [llvm_ptr_ty], [self.value], loc=loc, ip=ip + ) + + def __add__(self, offset: IntTuple) -> Pointer: + """ + Offset the pointer by elements of a layout's codomain. + + :param offset: The offset to add to the pointer + :type offset: IntTuple + :return: A new pointer offset by the specified amount + :rtype: ir.Value + """ + offset = _pack_int_tuple(offset) + return _cute_ir.add_offset(self.value, offset=offset) + + @dsl_user_op + def toint(self, *, loc=None, ip=None): + if self.memspace in (AddressSpace.gmem, AddressSpace.generic): + res_type = Int64 + else: + res_type = Int32 + + return res_type( + _cute_ir.ptrtoint(res_type.mlir_type, self.value, loc=loc, ip=ip) + ) + + @dsl_user_op + def align(self, min_align: int, *, loc=None, ip=None) -> Pointer: + """ + Align a pointer to a specified byte alignment. + + :param min_align: The minimum byte alignment requirement. Must be a power of 2. 
+ :type min_align: int + :param loc: The source location for the operation, defaults to None + :type loc: Location, optional + :param ip: The insertion point for the operation, defaults to None + :type ip: InsertionPoint, optional + :return: The aligned new pointer that satisfies alignment request. + :rtype: Pointer + :raises ValueError: If the alignment is not a power of 2. + :raises TypeError: If pointer is in tmem address space. + """ + + if (min_align & (min_align - 1)) != 0: + raise ValueError("Alignment must be a power of 2") + + assert isinstance(self.type, _cute_ir.PtrType) + if self.memspace is AddressSpace.tmem: + raise ValueError("aligning a TMEM pointer is not supported") + + if min_align <= self.alignment: + return self + + dtype = Numeric.from_mlir_type(self.type.value_type) + # Convert pointer to integer + address_int = self.toint(loc=loc, ip=ip) + # Align the address + aligned_address = (address_int + min_align - 1) & ~(min_align - 1) + + return make_ptr( + dtype, + aligned_address, + self.memspace, + assumed_align=min_align, + loc=loc, + ip=ip, + ) + + +@ir.register_value_caster(_cute_ir.MemRefType.get_static_typeid(), replace=True) +@ir.register_value_caster(_cute_ir.CoordTensorType.get_static_typeid(), replace=True) +@ir.register_value_caster( + _cute_nvgpu_ir.SmemDescViewType.get_static_typeid(), replace=True +) +class _Tensor(Tensor): + """A tensor class representing the composition of an iterator (engine) with a layout. + + A tensor evaluates the layout by mapping a coordinate to the codomain, offsets the + iterator accordingly, and dereferences the result to obtain the tensor's value. + Formally: T(c) = (E ∘ L)(c) = *(E + L(c)), where E is the iterator/engine and L is the layout. + + :param value: The MLIR operation result value to initialize the tensor with + :type value: ir.Value + :param dtype: The user specified data type of the tensor elements. It could be \ + different from the underlying dtype in the iterator. The default is None. 
+ :type dtype: Type[Numeric], optional + + Attributes: + iterator: The pointer or iterator (engine) component of the tensor + layout: The layout component defining the mapping from coordinates to offsets + shape: The shape of the tensor, inherited from the layout + stride: The stride of the tensor, inherited from the layout + element_type: The data type of the tensor elements + memspace: The memory space where the tensor data resides + + Notes: + - The tensor supports both direct element access via coordinates and slicing operations + - Load/store operations are only supported for specific memory spaces (rmem, smem, gmem, generic) + - For composed layouts, stride information is not directly accessible + - Dynamic layouts do not support vector load/store operations + + **Examples:** + + .. code-block:: python + + # Create a tensor with shape (4,8) in row-major layout + tensor = make_tensor(ptr, make_layout(shape=(4,8), stride=(8,1))) + + # Access individual element + val = tensor[0, 0] # or val = tensor[(0, 0)] + + # Slice operation - get first column + subtensor = tensor[None, 0] # or subtensor = tensor[(None, 0)] + """ + + def __init__(self, value, dtype: Optional[Type[Numeric]] = None): + self._dtype = dtype + if isinstance(value, ir.Value): + self.value = value + elif isinstance(value, _Tensor): + self.value = value.value + else: + raise TypeError(f"Expected ir.Value or core._Tensor, got {type(value)}") + + # Set iterator + iter_val = _cute_ir.get_iter(self.value) + if isinstance(iter_val, Pointer): + self._iterator = iter_val + elif isinstance(iter_val.type, _cute_ir.IntTupleType): + self._iterator = _unpack_x_tuple(iter_val) + elif isinstance(iter_val, ir.Value): + # Example: SMEM descriptor iterator, not well supported today + self._iterator = iter_val + else: + raise TypeError(f"unsupported iterator type, got {type(iter_val)}") + + # Set dtype + if self._dtype is None: + if is_int_tuple(self.iterator): + self._dtype = IntTuple + elif 
isinstance(self.iterator, Pointer): + self._dtype = self.iterator.value_type + elif isinstance(self.type, _cute_nvgpu_ir.SmemDescViewType): + # SmemDescViewType do not need dtype + self._dtype = None + else: + raise TypeError(f"unsupported iterator type, got {type(self.iterator)}") + + def __str__(self): + return f"tensor<{pretty_str(self.iterator)} o {pretty_str(self.layout)}>" + + def __extract_mlir_values__(self): + return [self.value] + + def __new_from_mlir_values__(self, values): + # Only expecting single value of _Tensor or ir.Value + # In this context, a _Tensor instance is an encapsulated ir.Value which is automatically created + # by value caster for MemRef/CoordTensor/SmemDescView typed values + assert len(values) == 1, f"Expected 1 value, but got {len(values)}" + assert isinstance( + values[0], (_Tensor, ir.Value) + ), f"Expected _Tensor or ir.Value, but got {type(values[0])}" + return _Tensor( + values[0] if isinstance(values[0], ir.Value) else values[0].value, + dtype=self.element_type, + ) + + # Cheat to let `Type(_Tensor())` to return cute.Tensor + @property + def __class__(self) -> Type[Tensor]: + return Tensor + + # Make it behave as if it inherited from ir.Value + @property + @lru_cache_ir() + def type(self) -> ir.Type: + return self.value.type + + @dsl_user_op + def __getitem__( + self, crd: Coord, *, loc=None, ip=None + ) -> Union[Tensor, Numeric, IntTuple]: + """Access or slice tensor elements using coordinates. 
+ + This method implements + * tensor evaluation T(c) = *(E + L(c)) when `c` is a coordinate without slicing, or + * tensor slicing operations T(c) = make_tensor(E + L(c), slice(L, c)) + where E is the iterator/engine and L is the layout + + :param crd: Coordinate or slice specification for accessing tensor elements + :type crd: Coord + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: Tensor element value or sliced subtensor + :rtype: Union[Tensor, ir.Value, IntTuple] + + :raises ValueError: If coordinate access is invalid for the tensor layout + + **Examples:** + + .. code-block:: python + + # Create a tensor with pointer iterator + ptr = make_ptr(cutlass.Float32, 0, cutlass.AddressSpace.gmem) + layout = make_layout((64, 128)) # leftmost mode is major + tensor = make_tensor(ptr, layout) # Tensor using pointer iterator + + # Direct element access loads from memory + val = tensor[0] # Loads element at offset 0 + val = tensor[1] # Loads element at offset 4 (4bytes per Float32) + val = tensor[(0, 1)] # Loads element at offset 64 + + # Create a coord tensor + layout = make_layout((64, 128), stride=(1 * E(0), 1 * E(1))) + tensor = make_tensor((128, 128), layout) + + # Direct element access + val = tensor[0] # Returns (128, 128) + val = tensor[(0, 1)] # Returns (128, 129) + + # Slice access + sliced = view[(3, None)] # Returns tensor slice + + .. note:: + Sub-byte types like Float4E2M1FN and Float6E3M2FN are not supported for scalar + dereference operations. Attempting to set individual elements of tensors with + these element types will result in errors. + + **Examples:** + + .. 
code-block:: python + + # Unsupported operations with sub-byte types: + ptr = make_ptr(cutlass.Float4E2M1FN, 0, cutlass.AddressSpace.gmem) + tensor = make_tensor(ptr, layout) + # The following will raise an error: + val = tensor[0] # Error: sub-byte scalar dereference not supported + + # Similarly for other sub-byte types: + ptr = make_ptr(cutlass.Float6E3M2FN, 0, cutlass.AddressSpace.gmem) + tensor = make_tensor(ptr, layout) + val = tensor[0] # Error: sub-byte scalar dereference not supported + """ + if has_underscore(crd): + return slice_(self.value, crd) + elif isinstance(self.type, _cute_ir.CoordTensorType): + res = _cute_ir.get_iter(slice_(self, crd).value, loc=loc, ip=ip) + return _unpack_x_tuple(res) + else: + self._check_can_load_store() + self._check_can_dereference() + + crd_val = _pack_coord(crd, loc=loc, ip=ip) + data_val = _cute_ir.memref_load(self.value, crd_val, loc=loc, ip=ip) + return self.element_type(data_val) + + def _cvt_to_dest(self, data: Union["TensorSSA", Numeric], *, loc=None, ip=None): + orig_dtype = data.dtype + # Implicit upcast to wider type + if ( + data.dtype.is_same_kind(self.element_type) + and self.element_type.width >= data.dtype.width + ): + data = data.to(self.element_type, loc=loc, ip=ip) # type: ignore + + if data.dtype.width != self.element_type.width: + raise ValueError( + f"Type mismatch, store {orig_dtype} (-> {data.dtype}) " + f"to Tensor with element type {self.element_type}" + ) + + if data.dtype is Boolean and self.element_type is Boolean: + # Boolean Numeric and Boolean TensorSSA both hold i1 value, but we need int8 value store to memory + val = data.ir_value_int8() + else: + val = data.ir_value() + return val + + @dsl_user_op + def __setitem__( + self, + crd: Coord, + data: Union[int, float, ir.Value, Numeric, "TensorSSA"], + *, + loc=None, + ip=None, + ) -> None: + """Set tensor elements at specified coordinates. + + Assigns values to tensor elements through direct coordinate access or slice assignment. 
+ For slice assignment, the value must be a TensorSSA with matching shape. + + :param crd: Coordinate or slice specification for tensor element assignment + :type crd: Coord + :param data: Value to assign - can be scalar or TensorSSA for slice assignment + :type data: Union[int, float, ir.Value, Numeric, TensorSSA] + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + + :raises ValueError: If tensor type doesn't support load/store operations + :raises ValueError: If slice assignment value is not a TensorSSA + :raises ValueError: If value type doesn't match tensor element type + :raises NotImplementedError: If value type is not supported + + .. note:: + Sub-byte types like Float4E2M1FN and Float6E3M2FN are not supported for scalar + dereference operations. Attempting to set individual elements of tensors with + these element types will result in errors. + + **Examples:** + + .. 
code-block:: python + + # Unsupported operations with sub-byte types: + ptr = make_ptr(cutlass.Float4E2M1FN, 0, cutlass.AddressSpace.gmem) + tensor = make_tensor(ptr, layout) + # The following will raise an error: + tensor[0] = 1.0 # Error: sub-byte scalar dereference not supported + + # Similarly for other sub-byte types: + ptr = make_ptr(cutlass.Float6E3M2FN, 0, cutlass.AddressSpace.gmem) + tensor = make_tensor(ptr, layout) + tensor[0] = 0.5 # Error: sub-byte scalar dereference not supported + """ + self._check_can_load_store() + + # convert scalar type + if not has_underscore(crd): + self._check_can_dereference() + # First, convert ir.Value to Numeric + if isinstance(data, ir.Value): + data = as_numeric(data) + elif isinstance(data, (int, float, bool)): + data = as_numeric(data) + + if not isinstance(data, Numeric): + raise ValueError(f"unsupported data type: {type(data)}") + + # Implicit upcast to wider type + val = self._cvt_to_dest(data, loc=loc, ip=ip) + if val.type != self.type.value_type: + raise ValueError( + f"type mismatch, store {val.type} to {self.element_type}" + ) + + crd_val = _pack_coord(crd, loc=loc, ip=ip) + _cute_ir.memref_store(self.value, crd_val, val, loc=loc, ip=ip) + else: + if not isinstance(data, TensorSSA): + raise ValueError(f"expects TensorSSA, but got {data}") + + self.__getitem__(crd).store(data, loc=loc, ip=ip) # type: ignore + + @property + def __class__(self) -> Type[Tensor]: + return Tensor + + # Make it behave as if it inherited from ir.Value + @property + @lru_cache_ir() + def type(self) -> ir.Type: + return self.value.type + + @property + def iterator(self) -> Union[Pointer, IntTuple]: + return self._iterator + + @property + def layout(self) -> Layout: + return _cute_ir.get_layout(self.value) + + @property + def shape(self) -> Shape: + return self.layout.shape + + @property + def stride(self) -> Stride: + if isinstance(self.type, _cute_ir.ComposedLayoutType): + raise ValueError(f"can't get stride from composed layout") + 
return self.layout.stride + + @property + def leading_dim(self) -> Union[int, Tuple[int], None]: + """Get the leading dimension of this Tensor. + + :return: The index or indices of the first mode (from left to right) with stride 1 + :rtype: Union[int, Tuple[int], None] + :returns: + - int: Single leading dimension index if found + - Tuple[int]: Tuple of indices for nested leading dimensions + - None: If no leading dimension is found + + :postcondition: ``get(self.stride(), mode=self.leading_dim()) == 1 if self.leading_dim() != None else True`` + """ + return leading_dim(self.shape, self.stride) + + @property + @lru_cache_ir() + def element_type(self) -> Union[Type[Numeric], Type[IntTuple]]: + return self._dtype + + @property + @lru_cache_ir() + def memspace(self) -> AddressSpace: + if isinstance(self.iterator, Pointer): + return self.iterator.memspace + + raise ValueError(f"{self} doesn't have memspace") + + @dsl_user_op + def load(self, *, loc=None, ip=None) -> "TensorSSA": + """Load tensor elements as a vector. + + Loads all elements of the tensor into a vector representation, assuming the tensor + has a static shape and is in a memory space that supports load operations. 
+ + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: Vector representation of tensor elements + :rtype: TensorSSA + + :raises ValueError: If tensor has dynamic layout + :raises ValueError: If tensor memory space doesn't support load operations + """ + if not is_static(self.shape): + raise ValueError("dynamic layout doesn't support load") + + self._check_can_load_store() + + res_vect = _cute_ir.memref_load_vec(self.value, row_major=True, loc=loc, ip=ip) + if self.element_type is Boolean: + assert ( + res_vect.type.element_type == T.i8() + ), f"Boolean tensor must be stored as i8 in memory, but got {res_vect.type.element_type}" + zeros = full_like(self, 0, Int8, loc=loc, ip=ip) + res_vect = arith.cmpi( + arith.CmpIPredicate.ne, res_vect, zeros, loc=loc, ip=ip + ) + return TensorSSA(res_vect, self.shape, self.element_type) + + @dsl_user_op + def store(self, data: "TensorSSA", *, loc=None, ip=None): + """Store vector data into tensor. + + Stores vector data into the tensor, assuming matching shapes and a memory space + that supports store operations. 
+ + :param data: Vector data to store into tensor + :type data: TensorSSA + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + + :raises ValueError: If tensor has dynamic layout + :raises ValueError: If tensor memory space doesn't support store operations + :raises ValueError: If data shape doesn't match tensor shape + """ + if not isinstance(data, TensorSSA): + raise ValueError(f"Expects TensorSSA, but got {type(data)}") + + if not is_static(self.shape): + raise ValueError("Dynamic layout doesn't support vectorized store") + + self._check_can_load_store() + + n_elems = size(self.shape, loc=loc, ip=ip) + if n_elems != size(data.shape, loc=loc, ip=ip): + raise ValueError( + f"lhs and rhs must have the same shape, but got {self.shape} and {data.shape}" + ) + + elem_mlir_type = cutlass_arith.element_type(data.dtype.mlir_type) + if cutlass_arith.is_narrow_precision(elem_mlir_type): + if elem_mlir_type.width * n_elems % 32 != 0: + raise ValueError( + f"narrow precision type must be 32-bit aligned vector, but got {elem_mlir_type} with {n_elems} elements" + ) + + # Implicit upcast to wider type + new_data = self._cvt_to_dest(data, loc=loc, ip=ip) + + return _cute_ir.memref_store_vec( + new_data, self.value, row_major=True, loc=loc, ip=ip + ) + + @dsl_user_op + def fill(self, value: Numeric, *, loc=None, ip=None) -> None: + """Fill tensor with a constant value. + + Fills all elements of the tensor with the specified value, assuming static size + and supported memory space. 
+ + :param value: Value to fill tensor with + :type value: Union[int, float] + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + + :raises NotImplementedError: If tensor has dynamic size + + **Examples:** + + .. code-block:: python + + # Create tensor from numpy array + b = np.random.randn(4, 8).astype(np.float32) + tensor = from_dlpack(b) + + # Fill tensor with constant value + tensor.fill(0.5) # All elements become 0.5 + """ + self._check_can_load_store() + + sz = size(self, loc=loc, ip=ip) + if type(sz) is not int: + raise NotImplementedError(f"dynamic size is not supported: {self.type}") + + # Should we cast to destination type even with narrow cast? + dst_type = self.element_type + value = dst_type(value) + + self[None] = full(self.shape, fill_value=value, dtype=dst_type, loc=loc, ip=ip) + + def _check_can_load_store(self): + if not isinstance(self.type, _cute_ir.MemRefType) or not self.memspace in ( + AddressSpace.rmem, + AddressSpace.smem, + AddressSpace.gmem, + AddressSpace.generic, + ): + raise ValueError(f"{self} doesn't support load and store") + + def _check_can_dereference(self): + # Check for sub-byte types and raise error if needed + if self.element_type.width % 8 != 0 and self.element_type is not Boolean: + raise ValueError( + f"Sub-byte scalar dereference not supported for type {self.element_type}" + ) + + +@dsl_user_op +def print_tensor( + tensor: Union[Tensor, "TensorSSA"], *, verbose: bool = False, loc=None, ip=None +): + """Print content of the tensor in human readable format. + + Outputs the tensor data in a structured format showing both metadata + and the actual data values. The output includes tensor type information, + layout details, and a formatted array representation of the values. 
+ + :param tensor: The tensor to print + :type tensor: Tensor + :param verbose: If True, includes additional debug information in the output + :type verbose: bool + :param loc: Source location where it's called, defaults to None + :type loc: source location, optional + :param ip: Insertion pointer for IR generation, defaults to None + :type ip: insertion pointer, optional + :raises NotImplementedError: If the tensor type doesn't support trivial dereferencing + + **Example output:** + + .. code-block:: text + + tensor(raw_ptr<@..., Float32, generic, align(4)> o (8,5):(5,1), data= + [[-0.4326, -0.5434, 0.1238, 0.7132, 0.8042], + [-0.8462, 0.9871, 0.4389, 0.7298, 0.6948], + [ 0.3426, 0.5856, 0.1541, 0.2923, 0.6976], + [-0.1649, 0.8811, 0.1788, 0.1404, 0.2568], + [-0.2944, 0.8593, 0.4171, 0.8998, 0.1766], + [ 0.8814, 0.7919, 0.7390, 0.4566, 0.1576], + [ 0.9159, 0.7577, 0.6918, 0.0754, 0.0591], + [ 0.6551, 0.1626, 0.1189, 0.0292, 0.8655]]) + """ + if isinstance(tensor, TensorSSA): + tmp = make_fragment(tensor.shape, tensor.dtype) + tmp.store(tensor) + tensor = tmp + + if not isinstance(tensor.type, _cute_ir.MemRefType): + raise NotImplementedError( + f"printing {tensor} is not supported because it doesn't support trivial dereferencing. " + f"Coordinate Tensor will be supported in the future." 
+ ) + + tensor._check_can_load_store() # type: ignore + + if tensor.element_type.is_integer: + signed = tensor.element_type.signed + else: + signed = False + + _cute_ir.print_view(tensor.value, verbose=verbose, is_signed=signed, loc=loc, ip=ip) + + +#################################################################################################### +# +# Core API +# +#################################################################################################### + + +# +# Utilties +# + + +@lru_cache_ir() +def is_integer(a) -> bool: + """Check if an object is static integer or dynamic integer""" + return isinstance(a, (int, Integer)) or ( + isinstance(a, ir.Value) + and isinstance(a.type, (ir.IntegerType, _cute_ir.ConstrainedIntType)) + ) + + +def is_valid_leaf(a) -> bool: + """ + Returns whether `a` has a type that is valid for a CuTe tuple's leaf. + """ + return ( + is_integer(a) + or (a is None) + or isinstance(a, (ScaledBasis, Layout, ComposedLayout)) + ) + + +def is_int_tuple(a) -> bool: + if isinstance(a, tuple): + return all([is_int_tuple(x) for x in a]) + else: + return is_integer(a) + + +def is_static(x: Union[ir.Type, ir.Value, XTuple]) -> bool: + """Check if a value is statically known at compile time. + + In CuTe, static values are those whose values are known at compile time, + as opposed to dynamic values which are only known at runtime. + + :param x: The value to check + :type x: Union[ir.Type, ir.Value, XTuple] + :return: True if the value is static, False otherwise + :rtype: bool + :raises TypeError: If an unsupported type is provided + """ + if isinstance(x, ir.Type): + return _cute_ir.is_static(x) + elif isinstance(x, tuple): + return all(is_static(a) for a in x) + # Can it be a static int? 
+ elif isinstance(x, Numeric): + return False + elif is_dynamic_expression(x): + return _cute_ir.is_static(x.type) + elif isinstance(x, (bool, int, float)) or x is None: + return True + elif isinstance(x, ScaledBasis): + return x.is_static() + else: + raise TypeError(f"unsupported type {x}") + + +def has_underscore(a: XTuple) -> bool: + if type(a) is tuple: + return any([has_underscore(x) for x in a]) + else: + return a is None + + +def has_scaled_basis(a: XTuple) -> bool: + """Check if a tuple or its nested elements contain ScaledBasis objects. + + ScaledBasis objects are fundamental components in CuTe layouts, + representing the basis vectors of coordinate systems. + + :param a: The tuple to check + :type a: XTuple + :return: True if the tuple contains ScaledBasis objects, False otherwise + :rtype: bool + """ + if type(a) is tuple: + return any([has_scaled_basis(x) for x in a]) + else: + return isinstance(a, ScaledBasis) + + +def _tuple_str(t: tuple) -> str: + """ + Constructs a string representation of a python tuple without calling __repr__ on its elements. + """ + + def construct_inner_str(t) -> str: + if not isinstance(t, tuple): + return pretty_str(t) + res = "" + l = len(t) + for i in range(l): + res += pretty_str(t[i]) + if i < l - 1: + res += "," + return res + + res = "(" + construct_inner_str(t) + ")" + return res + + +def pretty_str(arg) -> str: + """ + Constructs a concise readable pretty string. + """ + if isinstance(arg, tuple): + # _tuple_str for tuples + return _tuple_str(arg) + elif arg is None: + # We interpret None as underscores for slicers + return "_" + else: + # Fallback to __str__ + return arg.__str__() + + +@dsl_user_op +def printf(*args, loc=None, ip=None) -> None: + """ + Print a value or a list of values. + + It supports c-style printf format as well: + + .. 
code-block:: python + + a = cute.make_layout(shape=(10, 10), stride=(10, 1)) + b = cutlass.Float32(1.234) + cute.printf(a, b) + cute.printf("a={}, b={}", a, b) + cute.printf("a={}, b=%.2f", a, b) + + :param args: List of values to print + :type args: list + :param loc: Source location where it's called, defaults to None + :type loc: source location, optional + :param ip: Insertion pointer, defaults to None + :type ip: insertion pointer, optional + :raises ValueError: If no arguments are provided or if an unsupported argument type is passed + """ + + if len(args) == 0: + raise ValueError("expects at least one argument to print") + + if isinstance(args[0], str): + fmt = args[0] + "\n" + args = args[1:] + else: + fmt = "{}" + ", {}" * (len(args) - 1) + "\n" + + def process_arg(arg): + arg0 = arg.value if isinstance(arg, Numeric) else arg + + if isinstance(arg0, ir.Value): + return arg0 + elif isinstance(arg0, bool): + return const(arg0, Boolean) + elif isinstance(arg0, int): + return const(arg0, Int32) + elif isinstance(arg0, float): + return const(arg0, Float32) + elif has_underscore(arg0): + # Assume it's a coordinate + return _pack_coord(arg0) + elif has_scaled_basis(arg0): + # Assume it's a stride + return _pack_stride(arg0) + elif isinstance(arg0, tuple): + # Assume it's an int_tuple + return _pack_int_tuple(arg0) + elif isinstance(arg0, (_Tensor, _Pointer)): + return arg0.value + else: + raise TypeError(f"unsupported argument type in printf, got {type(arg)}") + + args = [process_arg(a) for a in args] + _cute_ir.print_(args, fmt=fmt, loc=loc, ip=ip) + + +@dsl_user_op +def front(input, *, loc=None, ip=None): + """Recursively get the first element of input. + + This function traverses a hierarchical structure (like a layout or tensor) + and returns the first element at the deepest level. It's particularly useful + for accessing the first stride value in a layout to determine properties like + majorness. 
+ + :param input: The hierarchical structure to traverse + :type input: Union[Tensor, Layout, Stride] + :param loc: Source location where it's called, defaults to None + :type loc: source location, optional + :param ip: Insertion pointer for IR generation, defaults to None + :type ip: insertion pointer, optional + :return: The first element at the deepest level of the input structure + :rtype: Union[int, float, bool, ir.Value] + """ + if rank(input) == 1 and depth(input) == 0: + return input + else: + return front(get(input, mode=[0], loc=loc, ip=ip), loc=loc, ip=ip) + + +@dsl_user_op +def is_major(mode, stride: Stride, *, loc=None, ip=None) -> bool: + """ + Check whether a mode in stride is the major mode. + """ + first_stride = front(get(stride, mode=[mode], loc=loc, ip=ip), loc=loc, ip=ip) + if is_dynamic_expression(first_stride): + return False + return True if first_stride == 1 else False + + +def leading_dim(shape: Shape, stride: Stride) -> Union[int, Tuple[int, ...], None]: + """ + Find the leading dimension of a shape and stride. 
+ + :param shape: The shape of the tensor or layout + :type shape: Shape + :param stride: The stride of the tensor or layout + :type stride: Stride + :return: The leading dimension index or indices + :rtype: Union[int, Tuple[int, ...], None] + + The return value depends on the stride pattern: + + * If a single leading dimension is found, returns an integer index + * If nested leading dimensions are found, returns a tuple of indices + * If no leading dimension is found, returns None + """ + + def pred_fn(val, pos): + # skip dynamic values which can't be compared + # find the candidate target val, stride at this position is 1 + if (not is_dynamic_expression(val)) and (val == 1): + # extract the shape at this position + mode = [pos] if isinstance(pos, int) else list(pos) + s = get(shape, mode) + if is_dynamic_expression(s) or s != 1: + # shape at this position is dynamic value or not 1 + # we found the leading dimension + return True + return False + + return find_if(stride, pred_fn=pred_fn) + + +@dsl_user_op +def find_if( + t: Union[tuple, ir.Value, int], + pred_fn: Callable[[int, Tuple[int, ...]], bool], + *, + loc=None, + ip=None, +) -> Union[int, Tuple[int, ...], None]: + """Find the first position in t where pred_fn(val, pos) returns True. + + :param t: The search space + :type t: Union[tuple, ir.Value, int] + :param pred_fn: A callable object (lambda, function, etc.) that predicates the value and position in t. + It takes the current leaf value and position, returns True if the value or position is satisfied. + :type pred_fn: Callable[[int, Tuple[int, ...]], bool] + :return: Index if found at top level, tuple of indices showing nested position, or None if not found + :rtype: Union[int, Tuple[int, ...], None] + + **Examples:** + + .. code-block:: python + + # Find the first position of x in t + t = (3, 4) + find_if(t, pred_fn=lambda val, pos: val == x) + + .. 
code-block:: python + + # find the leading dimension + shape = (3, 4) + stride = (4, 1) + # Find value 1 in stride where the corresponding shape is not 1 + def pred_fn(val, pos): + mode = [pos] if isinstance(pos, int) else list(pos) + return val == 1 and get(shape, mode) != 1 + find_if(stride, pred_fn=pred_fn) + """ + + def _find_if_impl(curr, pos, *, loc=None, ip=None): + if isinstance(curr, tuple): + # Recursively search nested tuple + for i in range(rank(curr)): + sub_curr = get(curr, mode=[i], loc=loc, ip=ip) + sub_pos = (pos, i) if isinstance(pos, int) else pos + (i,) + res_pos = _find_if_impl(sub_curr, sub_pos, loc=loc, ip=ip) + if res_pos is not None: + return res_pos + else: + # For leaf values, check if it matches x + if pred_fn(curr, pos): + return pos + return None + + def _check_pred_fn(): + if not callable(pred_fn): + raise TypeError(f"pred_fn must be callable, but got {type(pred_fn)}") + signature = inspect.signature(pred_fn) + if len(signature.parameters) != 2: + raise ValueError( + f"pred_fn must have two parameters (value, pos), but got {len(signature.parameters)}" + ) + + _check_pred_fn() + + for i in range(rank(t)): + curr = get(t, mode=[i], loc=loc, ip=ip) + res_pos = _find_if_impl(curr, i, loc=loc, ip=ip) + if res_pos is not None: + return res_pos + return None + + +@dsl_user_op +def find( + t: Union[tuple, ir.Value, int], + x: int, + *, + loc=None, + ip=None, +) -> Union[int, Tuple[int, ...], None]: + """Find the first position of a value ``x`` in a hierarchical structure ``t``. + + Searches for the first occurrence of x in t, optionally excluding positions + where a comparison value matches. The search can traverse nested structures + and returns either a single index or a tuple of indices for nested positions. 
+ + :param t: The search space + :type t: Union[tuple, ir.Value, int] + :param x: The static integer x to search for + :type x: int + :return: Index if found at top level, tuple of indices showing nested position, or None if not found + :rtype: Union[int, Tuple[int, ...], None] + """ + if not isinstance(x, int): + raise TypeError(f"find() requires a static x to search for, but got {x}") + + def pred_fn(val, pos): + # Skip dynamic values which can't be compared + return not is_dynamic_expression(val) and val == x + + return find_if(t, pred_fn=pred_fn, loc=loc, ip=ip) + + +def transform_leaf(f, *args): + """ + Apply a function to the leaf nodes of nested tuple structures. + + This function traverses nested tuple structures in parallel and applies the function f + to corresponding leaf nodes. All input tuples must have the same nested structure. + + :param f: Function to apply to leaf nodes + :type f: Callable + :param args: One or more nested tuple structures with matching profiles + :return: A new nested tuple with the same structure as the inputs, but with leaf values transformed by f + :raises TypeError: If the input tuples have different nested structures + + Example: + + .. 
code-block:: python + + >>> transform_leaf(lambda x: x + 1, (1, 2)) + (2, 3) + >>> transform_leaf(lambda x, y: x + y, (1, 2), (3, 4)) + (4, 6) + >>> transform_leaf(lambda x: x * 2, ((1, 2), (3, 4))) + ((2, 4), (6, 8)) + """ + if all(isinstance(t, tuple) for t in args): + return tuple(transform_leaf(f, *_args) for _args in zip(*args)) + elif all(not isinstance(t, tuple) for t in args): + return f(*args) + else: + raise TypeError(f"profile of input tuples doesn't match: {args}") + + +@dsl_user_op +def assume(src, divby=None, *, loc=None, ip=None): + if divby is None: + return src + + if isinstance(src, Integer): + width = type(src).width + src_val = src.ir_value() + else: + width = src.type.width + src_val = src + + res_ty = _cute_ir.ConstrainedIntType.get(divby, width) + assumed_val = _cute_ir.assume(res_ty, src_val, loc=loc, ip=ip) + return type(src)(IntValue(_pack_int_tuple(assumed_val, loc=loc, ip=ip))) + + +@dsl_user_op +def make_swizzle(b, m, s, *, loc=None, ip=None): + # canonicalize to <0, 4, 3> for identity swizzle (as compiler assumes <0, 4, 3>) + if b == 0: + m, s = 4, 3 + ty = ir.Type.parse(f'!cute.swizzle<"S<{b},{m},{s}>">') + return Swizzle(_cute_ir.static(ty, loc=loc, ip=ip)) + + +# +# Tuple API (also used by layouts and tensors) +# + + +def depth(a: Union[XTuple, Layout, "ComposedLayout"]) -> int: + """Returns the depth (nesting level) of a tuple, layout, or tensor. + + The depth of a tuple is the maximum depth of its elements plus 1. + For an empty tuple, the depth is 1. For layouts and tensors, the depth + is determined by the depth of their shape. For non-tuple values (e.g., integers), + the depth is considered 0. + + :param a: The object whose depth is to be determined + :type a: Union[XTuple, Layout, ComposedLayout, Tensor, Any] + :return: The depth of the input object + :rtype: int + + Example: + + .. 
code-block:: python + + >>> depth(1) + 0 + >>> depth((1, 2)) + 1 + >>> depth(((1, 2), (3, 4))) + 2 + """ + if type(a) is tuple: + if not a: + return 1 + return max(depth(x) for x in a) + 1 + elif isinstance(a, (Layout, ComposedLayout, Tensor)): + return depth(a.shape) + else: + return 0 + + +@lru_cache_ir() +def rank(a: Union[XTuple, Layout, "ComposedLayout"]) -> int: + """Returns the rank (dimensionality) of a tuple, layout, or tensor. + + The rank of a tuple is its length. For layouts and tensors, the rank is + determined by the rank of their shape. For non-tuple values (e.g., integers), + the rank is considered 1 for convenience. + + :param a: The object whose rank is to be determined + :type a: Union[XTuple, Layout, ComposedLayout, Tensor, Any] + :return: The rank of the input object + :rtype: int + + This function is used in layout algebra to determine the dimensionality + of tensors and layouts for operations like slicing and evaluation. + """ + if isinstance(a, tuple): + return len(a) + elif isinstance(a, (Layout, ComposedLayout, Tensor)): + return rank(a.shape) + elif depth(a) == 0: + return 1 + else: + raise TypeError(f"unsupported type in rank, got {type(a)}") + + +def is_congruent( + a: Union[XTuple, Layout, ComposedLayout, Tensor], + b: Union[XTuple, Layout, ComposedLayout, Tensor], +) -> bool: + """ + Returns whether a is congruent to b. + + Congruence is an equivalence relation between hierarchical structures. + + Two objects are congruent if: + * They have the same rank, AND + * They are both non-tuple values, OR + * They are both tuples AND all corresponding elements are congruent. + + Congruence requires type matching at each level -- scalar values match with + scalar values, and tuples match with tuples of the same rank. 
+ + :param a: First object to compare + :type a: Union[XTuple, Layout, ComposedLayout, Tensor] + :param b: Second object to compare + :type b: Union[XTuple, Layout, ComposedLayout, Tensor] + :return: True if a and b are congruent, False otherwise + :rtype: bool + """ + if isinstance(a, (Layout, ComposedLayout, Tensor)): + a = a.shape + if isinstance(b, (Layout, ComposedLayout, Tensor)): + b = b.shape + if isinstance(a, tuple) and isinstance(b, tuple): + return (len(a) == len(b)) and all(is_congruent(x, y) for x, y in zip(a, b)) + if isinstance(a, tuple) or isinstance(b, tuple): + return False + return True + + +def is_weakly_congruent( + a: Union[XTuple, Layout, ComposedLayout, Tensor], + b: Union[XTuple, Layout, ComposedLayout, Tensor], +) -> bool: + """ + Returns whether a is weakly congruent to b. + + Weak congruence is a partial order on hierarchical structures. + + Object X is weakly congruent to object Y if: + * X is a non-tuple value, OR + * X and Y are both tuples of the same rank AND all corresponding elements are weakly congruent. + + Weak congruence allows scalar values to match with tuples, making it useful + for determining whether an object has a hierarchical structure "up to" another. + + :param a: First object to compare + :type a: Union[XTuple, Layout, ComposedLayout, Tensor] + :param b: Second object to compare + :type b: Union[XTuple, Layout, ComposedLayout, Tensor] + :return: True if a and b are weakly congruent, False otherwise + :rtype: bool + """ + if isinstance(a, (Layout, ComposedLayout, Tensor)): + a = a.shape + if isinstance(b, (Layout, ComposedLayout, Tensor)): + b = b.shape + if not isinstance(a, tuple): + return True + if isinstance(a, tuple) and isinstance(b, tuple): + return (len(a) == len(b)) and all( + is_weakly_congruent(x, y) for x, y in zip(a, b) + ) + if isinstance(a, tuple) or isinstance(b, tuple): + return False + return True + + +@overload +def get(input: Shape, mode, *, loc=None, ip=None) -> Shape: ... 
+@overload +def get(input: Stride, mode, *, loc=None, ip=None) -> Stride: ... +@overload +def get(input: Coord, mode, *, loc=None, ip=None) -> Coord: ... +@overload +def get(input: IntTuple, mode, *, loc=None, ip=None) -> IntTuple: ... +@overload +def get(input: Tile, mode, *, loc=None, ip=None) -> Tile: ... +@overload +def get(input: Layout, mode, *, loc=None, ip=None) -> Layout: ... +@overload +def get(input: ComposedLayout, mode, *, loc=None, ip=None) -> ComposedLayout: ... + + +@dsl_user_op +def get(input, mode: List[int], *, loc=None, ip=None): + """Extract a specific element or sub-layout from a layout or tuple. + + This function recursively traverses the input according to the mode indices, + extracting the element at the specified path. For layouts, this operation + corresponds to extracting a specific sub-layout. + + :param input: The input layout or tuple to extract from + :type input: Layout, ComposedLayout, tuple + :param mode: Indices specifying the path to traverse for extraction + :type mode: List[int] + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: The extracted element or sub-layout + :rtype: Layout, ComposedLayout, or element type + :raises ValueError: If any index in mode is out of range + :raises TypeError: If mode contains non-integer elements or if input has unsupported type + + :postcondition: ``get(t, mode=find(x,t)) == x if find(x,t) != None else True`` + + **Examples:** + + .. 
code-block:: python + + layout = make_layout(((4, 8), (16, 1), 8), stride=((1, 4), (32, 0), 512)) + sub_layout = get(layout, mode=[0, 1]) # 8:4 + sub_layout = get(layout, mode=[1]) # (16, 1):(32, 0) + """ + # Empty mode returns input and terminates the recursive call + if not mode: + return input + + if rank(input) <= mode[0]: + raise ValueError( + f"elements in mode must be less than rank({input}), got {mode}" + ) + + if depth(input) == 0: + return input + elif isinstance(input, tuple): + if not isinstance(mode[0], int): + raise TypeError( + f"invalid element in mode, expects int, got {type(mode[0])}" + ) + return get(input[mode[0]], mode=mode[1:]) + else: + if not isinstance(input, (Layout, ComposedLayout)): + raise TypeError(f"unsupported type of input, got {type(input)}") + return _cute_ir.get( + input.type.get_op_res_type(mode=mode), input, mode=mode, loc=loc, ip=ip + ) + + +@overload +def select(input: Shape, mode, *, loc=None, ip=None) -> Shape: ... +@overload +def select(input: Stride, mode, *, loc=None, ip=None) -> Stride: ... +@overload +def select(input: Coord, mode, *, loc=None, ip=None) -> Coord: ... +@overload +def select(input: IntTuple, mode, *, loc=None, ip=None) -> IntTuple: ... +@overload +def select(input: Tile, mode, *, loc=None, ip=None) -> Tile: ... +@overload +def select(input: Layout, mode, *, loc=None, ip=None) -> Layout: ... +@overload +def select(input: ComposedLayout, mode, *, loc=None, ip=None) -> ComposedLayout: ... + + +@dsl_user_op +def select(input, mode: List[int], *, loc=None, ip=None): + """Select modes from input. 
+ + :param input: Input to select from + :type input: Layout, ComposedLayout, tuple + :param mode: Indices specifying which dimensions or elements to select + :type mode: List[int] + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: A new instance with selected dimensions/elements + :rtype: Layout, ComposedLayout, tuple + :raises ValueError: If any index in mode is out of range + :raises TypeError: If the input type is invalid + + **Examples:** + + .. code-block:: python + + # Select specific dimensions from a layout + layout = make_layout((4, 8, 16), stride=(32, 4, 1)) + selected = select(layout, mode=[0, 2]) # Select mode 0 and mode 2 + # Result: (4, 16):(32, 1) + + # Select elements from a tuple + t = (1, 2, 3, 4, 5) + selected = select(t, mode=[0, 2, 4]) # Select mode 0, mode 2, and mode 4 + # Result: (1, 3, 5) + """ + if any((not isinstance(i, int)) or (i >= rank(input)) for i in mode): + raise ValueError( + f"invalid mode element for input of rank {rank(input)}, got {mode=}" + ) + + if isinstance(input, tuple): + return tuple(input[i] for i in mode) + + if not isinstance(input, (Layout, ComposedLayout)): + raise TypeError(f"unsupported type of input, got {type(input)}") + + return _cute_ir.select(input, mode=mode, loc=loc, ip=ip) + + +@overload +def group_modes(input: Shape, begin: int, end: int, *, loc=None, ip=None) -> Shape: ... +@overload +def group_modes( + input: Stride, begin: int, end: int, *, loc=None, ip=None +) -> Stride: ... +@overload +def group_modes(input: Coord, begin: int, end: int, *, loc=None, ip=None) -> Coord: ... +@overload +def group_modes( + input: IntTuple, begin: int, end: int, *, loc=None, ip=None +) -> IntTuple: ... +@overload +def group_modes(input: Tile, begin: int, end: int, *, loc=None, ip=None) -> Tile: ... +@overload +def group_modes( + input: Layout, begin: int, end: int, *, loc=None, ip=None +) -> Layout: ... 
+@overload +def group_modes( + input: ComposedLayout, begin: int, end: int, *, loc=None, ip=None +) -> ComposedLayout: ... +@overload +def group_modes( + input: Tensor, begin: int, end: int, *, loc=None, ip=None +) -> Tensor: ... + + +@dsl_user_op +def group_modes(input, begin: int, end: int = -1, *, loc=None, ip=None): + """Group modes of a hierarchical tuple or layout into a single mode. + + This function groups a range of modes from the input object into a single mode, + creating a hierarchical structure. For tuples, it creates a nested tuple containing + the specified range of elements. For layouts and other CuTe objects, it creates + a hierarchical representation where the specified modes are grouped together. + + :param input: Input object to group modes from (layout, tuple, etc.) + :type input: Layout, ComposedLayout, tuple, Shape, Stride, etc. + :param beg: Beginning index of the range to group (inclusive) + :type beg: int + :param end: Ending index of the range to group (exclusive) + :type end: int + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: A new object with the specified modes grouped + :rtype: Same type as input with modified structure + + **Examples:** + + .. 
code-block:: python + + # Group modes in a tuple + t = (2, 3, 4, 5) + grouped = group_modes(t, 1, 3) # (2, (3, 4), 5) + + # Group modes in a layout + layout = make_layout((2, 3, 4, 5)) + grouped_layout = group_modes(layout, 1, 3) # Layout with shape (2, (3, 4), 5) + + # Group modes in a shape + shape = make_shape(2, 3, 4, 5) + grouped_shape = group_modes(shape, 0, 2) # Shape ((2, 3), 4, 5) + """ + if depth(input) == 0 and is_integer(input): + return (input,) + if isinstance(input, tuple): + return (*input[:begin], (input[begin:end]), *input[end:]) + return _cute_ir.group_modes( + input.value if isinstance(input, Tensor) else input, begin, end, loc=loc, ip=ip + ) + + +@overload +def slice_(src: Shape, coord: Coord, *, loc=None, ip=None) -> Shape: ... +@overload +def slice_(src: Stride, coord: Coord, *, loc=None, ip=None) -> Stride: ... +@overload +def slice_(src: Coord, coord: Coord, *, loc=None, ip=None) -> Coord: ... +@overload +def slice_(src: IntTuple, coord: Coord, *, loc=None, ip=None) -> IntTuple: ... +@overload +def slice_(src: Tile, coord: Coord, *, loc=None, ip=None) -> Tile: ... +@overload +def slice_(src: Layout, coord: Coord, *, loc=None, ip=None) -> Layout: ... +@overload +def slice_( + src: ComposedLayout, coord: Coord, *, loc=None, ip=None +) -> ComposedLayout: ... +@overload +def slice_(src: Tensor, coord: Coord, *, loc=None, ip=None) -> Tensor: ... + + +@dsl_user_op +def slice_(src, coord: Coord, *, loc=None, ip=None): + """Perform a slice operation on a source object using the given coordinate. + + This function implements CuTe's slicing operation which extracts a subset of elements + from a source object (tensor, layout, etc.) based on a coordinate pattern. The slice + operation preserves the structure of the source while selecting specific elements. + + :param src: Source object to be sliced (tensor, layout, tuple, etc.) 
+ :type src: Union[Tensor, Layout, IntTuple, Value] + :param coord: Coordinate pattern specifying which elements to select + :type coord: Coord + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] + :return: A new object containing the sliced elements + :rtype: Union[Tensor, Layout, IntTuple, tuple] + :raises ValueError: If the coordinate pattern is incompatible with source + + **Examples:** + + .. code-block:: python + + # Layout slicing + layout = make_layout((4,4)) + + # Select 1st index of first mode and keep all elements in second mode + sub_layout = slice_(layout, (1, None)) + + .. code-block:: python + + # Basic tensor slicing + tensor = make_tensor(...) # Create a 2D tensor + + # Select 1st index of first mode and keep all elements in second mode + sliced = slice_(tensor, (1, None)) + + .. code-block:: python + + # Select 2nd index of second mode and keep all elements in first mode + sliced = slice_(tensor, (None, 2)) + + Note: + - `None` represents keeping all elements in that mode + - Slicing preserves the layout/structure of the original object + - Can be used for: + * Extracting sub-tensors/sub-layouts + * Creating views into data + * Selecting specific patterns of elements + """ + + def lift_slice(a, b): + if isinstance(a, tuple): + if (not isinstance(b, tuple)) or (len(a) != len(b)): + raise ValueError("coord must be weakly congruent to src in slice_") + return reduce( + lambda p, q: p + q, (lift_slice(x, y) for x, y in zip(a, b)), () + ) + elif a is None: + return (b,) + else: + return () + + if is_integer(src) or isinstance(src, tuple): + if isinstance(coord, tuple): + if (not isinstance(src, tuple)) or (len(coord) != len(src)): + raise ValueError("coord must be weakly congruent to src in slice_") + return reduce( + lambda p, q: p + q, (lift_slice(x, y) for x, y in zip(coord, src)), () + ) + elif coord is None: 
+ return src + else: + return () + + res_type = None + if isinstance(src, Tensor): + res_type = src.element_type + src = src.value + coord_val = _pack_coord(coord, loc=loc, ip=ip) + res = _cute_ir.slice(input=src, coord=coord_val, loc=loc, ip=ip) + return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res + + +@overload +def dice(src: Shape, coord: Coord, *, loc=None, ip=None) -> Shape: ... +@overload +def dice(src: Stride, coord: Coord, *, loc=None, ip=None) -> Stride: ... +@overload +def dice(src: Coord, coord: Coord, *, loc=None, ip=None) -> Coord: ... +@overload +def dice(src: IntTuple, coord: Coord, *, loc=None, ip=None) -> IntTuple: ... +@overload +def dice(src: Tile, coord: Coord, *, loc=None, ip=None) -> Tile: ... +@overload +def dice(src: Layout, coord: Coord, *, loc=None, ip=None) -> Layout: ... +@overload +def dice(src: ComposedLayout, coord: Coord, *, loc=None, ip=None) -> ComposedLayout: ... + + +@dsl_user_op +@lru_cache_ir() +def dice(src, dicer, *, loc=None, ip=None): + """Keep modes in input when it is paired with an integer in dicer. + + This function performs dicing operation on the input based on the dicer coordinate. + Dicing is a fundamental operation in CuTe that allows selecting specific modes from + a tensor or layout based on a coordinate pattern. + + :param dicer: A static coordinate indicating how to dice the input + :type dicer: Coord + :param input: The operand to be diced on + :type input: Union[IntTuple, Shape, Stride, Coord, Layout, ComposedLayout] + :param loc: Source location information, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for IR generation, defaults to None + :type ip: Optional[InsertionPoint] + :return: The diced result with selected modes from the input + :rtype: Union[IntTuple, Shape, Stride, Coord, Layout, ComposedLayout] + :raises TypeError: If dicer has an unsupported type + :raises ValueError: If input is not provided + + **Examples:** + + .. 
code-block:: python + + # Basic dicing of a layout + layout = make_layout((32,16,8)) + + # Keep only first and last modes + diced = dice((1,None,1), layout) + + Note: + - The dicer coordinate must be static + - Use underscore (_) to remove a mode + """ + if not is_static(dicer): + raise ValueError(f"expects dicer to be static, but got {dicer}") + + def lift_dice(a, b): + if isinstance(a, tuple): + if (not isinstance(b, tuple)) or (len(a) != len(b)): + raise ValueError("dicer must be weakly congruent to input in dice") + return reduce( + lambda p, q: p + q, (lift_dice(x, y) for x, y in zip(a, b)), () + ) + elif a is None: + return () + else: + return (b,) + + if is_integer(src) or isinstance(src, tuple): + if isinstance(dicer, tuple): + if (not isinstance(src, tuple)) or (len(dicer) != len(src)): + raise ValueError("dicer must be weakly congruent to src in dice") + return reduce( + lambda p, q: p + q, (lift_dice(x, y) for x, y in zip(dicer, src)), () + ) + elif dicer is None: + return () + else: + return src + + dicer_val = _pack_coord(dicer, loc=loc, ip=ip) + return _cute_ir.dice(src, dicer_val.type.attribute, loc=loc, ip=ip) + + +def wrap(x) -> tuple: + """ + Wraps the input into a tuple if not a tuple. 
+ """ + if isinstance(x, tuple): + return x + return (x,) + + +def _extend(func, input, elem, up_to_rank, loc, ip): + if input is None: + raise ValueError(f"No input provided for input") + + if isinstance(input, (Layout, ComposedLayout)): + if elem is None: + elem = make_layout(1) + elif not isinstance(elem, Layout): + raise TypeError(f"Input type of elem ({type(elem)}) is not accepted!") + N = rank(input) + 1 if up_to_rank is None else up_to_rank + return func(N, input, elem, loc=loc, ip=ip) + + if is_valid_leaf(input) or isinstance(input, tuple): + if elem is None: + elem = 1 + if (not isinstance(elem, tuple)) and (not is_valid_leaf(elem)): + raise TypeError(f"Input type of elem ({type(elem)}) is not accepted!") + + input = wrap(input) + repeat_cnt = 1 if up_to_rank is None else up_to_rank - rank(input) + if repeat_cnt == 0: + return input + elif repeat_cnt < 0: + raise ValueError(f"up_to_rank must be >= rank(input)") + else: + if func is _cute_ir.prepend_to_rank: + return (elem,) * repeat_cnt + input + else: + return input + (elem,) * repeat_cnt + + raise TypeError(f"invalid type for input, got {type(input)}") + + +@overload +def prepend( + input: Shape, elem: Shape, up_to_rank=None, *, loc=None, ip=None +) -> Shape: ... +@overload +def prepend( + input: Stride, elem: Stride, up_to_rank=None, *, loc=None, ip=None +) -> Stride: ... +@overload +def prepend( + input: Coord, elem: Coord, up_to_rank=None, *, loc=None, ip=None +) -> Coord: ... +@overload +def prepend( + input: IntTuple, elem: IntTuple, up_to_rank=None, *, loc=None, ip=None +) -> IntTuple: ... +@overload +def prepend(input: Tile, elem: Tile, up_to_rank=None, *, loc=None, ip=None) -> Tile: ... +@overload +def prepend( + input: Layout, elem: Layout, up_to_rank=None, *, loc=None, ip=None +) -> Layout: ... +@overload +def prepend( + input: ComposedLayout, elem: Layout, up_to_rank=None, *, loc=None, ip=None +) -> ComposedLayout: ... 
+ + +@dsl_user_op +def prepend(input, elem, up_to_rank: Union[None, int] = None, *, loc=None, ip=None): + """Extend input to rank up_to_rank by prepending elem in front of input. + + This function extends the input object by prepending elements to reach a desired rank. + It supports various CuTe types including shapes, layouts, tensors etc. + + :param input: Source to be prepended to + :type input: Union[Shape, Stride, Coord, IntTuple, Tile, Layout, ComposedLayout, Tensor] + :param elem: Element to prepend to input + :type elem: Union[Shape, Stride, Coord, IntTuple, Tile, Layout] + :param up_to_rank: The target rank after extension, defaults to None + :type up_to_rank: Union[None, int], optional + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint] + :return: The extended result with prepended elements + :rtype: Union[Shape, Stride, Coord, IntTuple, Tile, Layout, ComposedLayout, Tensor] + :raises ValueError: If up_to_rank is less than input's current rank + :raises TypeError: If input or elem has unsupported type + + **Examples:** + + .. code-block:: python + + # Prepend to a Shape + shape = (4,4) + prepend(shape, 2) # Returns (2,4,4) + + # Prepend to a Layout + layout = make_layout((8,8)) + prepend(layout, make_layout((2,))) # Returns (2,8,8):(1,1,8) + + # Prepend with target rank + coord = (1,1) + prepend(coord, 0, up_to_rank=4) # Returns (0,0,1,1) + """ + return _extend(_cute_ir.prepend_to_rank, input, elem, up_to_rank, loc=loc, ip=ip) + + +@overload +def append( + input: Shape, elem: Shape, up_to_rank=None, *, loc=None, ip=None +) -> Shape: ... +@overload +def append( + input: Stride, elem: Stride, up_to_rank=None, *, loc=None, ip=None +) -> Stride: ... +@overload +def append( + input: Coord, elem: Coord, up_to_rank=None, *, loc=None, ip=None +) -> Coord: ... 
+@overload +def append( + input: IntTuple, elem: IntTuple, up_to_rank=None, *, loc=None, ip=None +) -> IntTuple: ... +@overload +def append(input: Tile, elem: Tile, up_to_rank=None, *, loc=None, ip=None) -> Tile: ... +@overload +def append( + input: Layout, elem: Layout, up_to_rank=None, *, loc=None, ip=None +) -> Layout: ... +@overload +def append( + input: ComposedLayout, elem: Layout, up_to_rank=None, *, loc=None, ip=None +) -> ComposedLayout: ... + + +@dsl_user_op +def append(input, elem, up_to_rank: Union[None, int] = None, *, loc=None, ip=None): + """Extend input to rank up_to_rank by appending elem to the end of input. + + This function extends the input object by appending elements to reach a desired rank. + It supports various CuTe types including shapes, layouts, tensors etc. + + :param input: Source to be appended to + :type input: Union[Shape, Stride, Coord, IntTuple, Tile, Layout, ComposedLayout, Tensor] + :param elem: Element to append to input + :type elem: Union[Shape, Stride, Coord, IntTuple, Tile, Layout] + :param up_to_rank: The target rank after extension, defaults to None + :type up_to_rank: Union[None, int], optional + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint] + :return: The extended result with appended elements + :rtype: Union[Shape, Stride, Coord, IntTuple, Tile, Layout, ComposedLayout, Tensor] + :raises ValueError: If up_to_rank is less than input's current rank + :raises TypeError: If input or elem has unsupported type + + **Examples:** + + .. 
code-block:: python + + # Append to a Shape + shape = (4,4) + append(shape, 2) # Returns (4,4,2) + + # Append to a Layout + layout = make_layout((8,8)) + append(layout, make_layout((2,))) # Returns (8,8,2):(1,8,1) + + # Append with target rank + coord = (1,1) + append(coord, 0, up_to_rank=4) # Returns (1,1,0,0) + + Note: + - The function preserves the structure of the input while extending it + - Can be used to extend tensors, layouts, shapes and other CuTe types + - When up_to_rank is specified, fills remaining positions with elem + - Useful for tensor reshaping and layout transformations + """ + return _extend(_cute_ir.append_to_rank, input, elem, up_to_rank, loc=loc, ip=ip) + + +@dsl_user_op +def prepend_ones( + t: Tensor, up_to_rank: Union[None, int] = None, *, loc=None, ip=None +) -> Tensor: + return make_tensor( + t.iterator, prepend(t.layout, make_layout(1), up_to_rank), loc=loc, ip=ip + ) + + +@dsl_user_op +def append_ones( + t: Tensor, up_to_rank: Union[None, int] = None, *, loc=None, ip=None +) -> Tensor: + return make_tensor( + t.iterator, append(t.layout, make_layout(1), up_to_rank), loc=loc, ip=ip + ) + + +def repeat_like(x, target): + """Creates an object congruent to target and filled with x. + + This function recursively creates a nested tuple structure that matches the structure + of the target, with each leaf node filled with the value x. + + :param x: The value to fill the resulting structure with + :type x: Any + :param target: The structure to mimic + :type target: Union[tuple, Any] + :return: A structure matching target but filled with x + :rtype: Union[tuple, Any] + + **Examples:** + + .. 
code-block:: python + + repeat_like(0, (1, 2, 3)) # Returns (0, 0, 0) + repeat_like(1, ((1, 2), 3)) # Returns ((1, 1), 1) + repeat_like(2, 5) # Returns 2 + """ + if not isinstance(target, tuple): + return x + if not target: + return () + if len(target) == 1: + return (repeat_like(x, target[0]),) + return tuple(repeat_like(x, t) for t in target) + + +def flatten_to_tuple(a: Union[IntTuple, Coord, Shape, Stride]) -> tuple: + """Flattens a potentially nested tuple structure into a flat tuple. + + This function recursively traverses the input structure and flattens it into + a single-level tuple, preserving the order of elements. + + :param a: The structure to flatten + :type a: Union[IntTuple, Coord, Shape, Stride] + :return: A flattened tuple containing all elements from the input + :rtype: tuple + + **Examples:** + + .. code-block:: python + + flatten_to_tuple((1, 2, 3)) # Returns (1, 2, 3) + flatten_to_tuple(((1, 2), 3)) # Returns (1, 2, 3) + flatten_to_tuple((1, (2, (3,)))) # Returns (1, 2, 3) + """ + if not isinstance(a, tuple): + return wrap(a) + else: + return tuple(chain.from_iterable(tuple(flatten_to_tuple(x) for x in a))) + + +@overload +def flatten(a: Union[IntTuple, Coord, Shape, Stride]) -> IntTuple: ... +@overload +def flatten(a: Tensor) -> Tensor: ... +@overload +def flatten(a: Layout) -> Layout: ... + + +def flatten(a): + """Flattens a CuTe data structure into a simpler form. + + For tuples, this function flattens the structure into a single-level tuple. + For layouts, it returns a new layout with flattened shape and stride. + For tensors, it returns a new tensor with flattened layout. + For other types, it returns the input unchanged. + + :param a: The structure to flatten + :type a: Union[IntTuple, Coord, Shape, Stride, Layout, Tensor] + :return: The flattened structure + :rtype: Union[tuple, Any] + + **Examples:** + + .. 
code-block:: python + + flatten((1, 2, 3)) # Returns (1, 2, 3) + flatten(((1, 2), (3, 4))) # Returns (1, 2, 3, 4) + flatten(5) # Returns 5 + flatten(Layout(shape, stride)) # Returns Layout(flatten(shape), flatten(stride)) + flatten(Tensor(layout)) # Returns Tensor(flatten(layout)) + + """ + if isinstance(a, Tensor): + return make_tensor(a.iterator, flatten(a.layout)) + elif isinstance(a, Layout): + return make_layout(flatten(a.shape), stride=flatten(a.stride)) + elif isinstance(a, tuple): + return flatten_to_tuple(a) + else: + return a + + +def unflatten( + sequence: Union[Tuple[Any, ...], List[Any], Iterable[Any]], profile: XTuple +) -> XTuple: + """Unflatten a flat tuple into a nested tuple structure according to a profile. + + This function transforms a flat sequence of elements into a nested tuple structure + that matches the structure defined by the profile parameter. It traverses the profile + structure and populates it with elements from the sequence. + + sequence must be long enough to fill the profile. Raises RuntimeError if it is not. 
+ + :param sequence: A flat sequence of elements to be restructured + :type sequence: Union[Tuple[Any, ...], List[Any], Iterable[Any]] + :param profile: A nested tuple structure that defines the shape of the output + :type profile: XTuple + :return: A nested tuple with the same structure as profile but containing elements from sequence + :rtype: XTuple + + Example: + >>> unflatten([1, 2, 3, 4], ((0, 0), (0, 0))) + ((1, 2), (3, 4)) + """ + + def _make_generator(): + for element in sequence: + yield element + + xs = _make_generator() + return transform_leaf(lambda _: next(xs), profile) + + +@dsl_user_op +def elem_less( + lhs: Union[Shape, IntTuple, Coord], + rhs: Union[Shape, IntTuple, Coord], + *, + loc=None, + ip=None, +): + lhs_val = _pack_coord(lhs, loc=loc, ip=ip) + rhs_val = _pack_coord(rhs, loc=loc, ip=ip) + return Boolean(_cute_ir.elem_less(lhs_val, rhs_val, loc=loc, ip=ip)) + + +@overload +def filter_zeros( + input: Layout, *, target_profile=None, loc=None, ip=None +) -> Layout: ... +@overload +def filter_zeros( + input: Tensor, *, target_profile=None, loc=None, ip=None +) -> Tensor: ... + + +@dsl_user_op +def filter_zeros(input, *, target_profile=None, loc=None, ip=None): + """Filter out zeros from a layout or tensor. + + This function removes zero-stride dimensions from a layout or tensor. + Refer to https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/02_layout_algebra.md + for more layout algebra operations. 
+ + :param input: The input layout or tensor to filter + :type input: Layout or Tensor + :param target_profile: Target profile for the filtered result, defaults to None + :type target_profile: optional + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: The filtered layout or tensor with zeros removed + :rtype: Layout or Tensor + :raises TypeError: If input is not a Layout or Tensor + """ + if not isinstance(input, (Layout, Tensor)): + raise TypeError(f"Expect layout or tensor as input but got {type(input)=}") + if isinstance(input, Tensor): + input = input.value + return _cute_ir.filter_zeros(input, target_profile=target_profile, loc=loc, ip=ip) + + +@dsl_user_op +def filter(input: Union[Layout, Tensor], *, loc=None, ip=None): + """Filter a layout or tensor. + + This function filters a layout or tensor according to CuTe's filtering rules. + + :param input: The input layout or tensor to filter + :type input: Layout or Tensor + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: The filtered layout or tensor + :rtype: Layout or Tensor + :raises TypeError: If input is not a Layout or Tensor + """ + if not isinstance(input, (Layout, Tensor)): + raise TypeError(f"Expect layout or tensor as input but got {type(input)=}") + if isinstance(input, _Tensor): + input = input.value + return _cute_ir.filter(input, loc=loc, ip=ip) + + +@dsl_user_op +def product(a: Union[IntTuple, Shape], *, loc=None, ip=None): + """Return product of the given IntTuple or Shape. + + Computes the product of all elements in the input tuple or shape. + Returns static value if type is static. 
+ + :param a: The input tuple or shape + :type a: IntTuple or Shape + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: Static product of IntTuple or Shape if static, otherwise a Value + :rtype: int or Value + :raises TypeError: If input is not an IntTuple or Shape + """ + if is_integer(a): + return a + if isinstance(a, tuple): + a_val = _pack_int_tuple(a, loc=loc, ip=ip) + res = _cute_ir.tuple_product(a_val, loc=loc, ip=ip) + return _unpack_x_tuple(res, loc=loc, ip=ip) + else: + raise TypeError(f"expects IntTuple or Shape, but got {type(a)}") + + +@overload +def product_like( + a: IntTuple, target_profile: XTuple, *, loc=None, ip=None +) -> IntTuple: ... +@overload +def product_like(a: Shape, target_profile: XTuple, *, loc=None, ip=None) -> Shape: ... + + +@dsl_user_op +def product_like( + a: Union[IntTuple, Shape], target_profile: XTuple, *, loc=None, ip=None +): + """Return product of the given IntTuple or Shape at leaves of `target_profile`. + + This function computes products according to the structure defined by target_profile. 
+ + :param a: The input tuple or shape + :type a: IntTuple or Shape + :param target_profile: The profile that guides how products are computed + :type target_profile: XTuple + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: The resulting tuple with products computed according to target_profile + :rtype: IntTuple or Shape + :raises TypeError: If inputs have incompatible types + :raises ValueError: If inputs have incompatible shapes + """ + # Perform product at leaf of `target_profile` + if not isinstance(target_profile, tuple): + return product(a, loc=loc, ip=ip) + else: + if not isinstance(a, tuple): + raise TypeError(f"expects `a` tuple but got {a}") + + if len(a) != len(target_profile): + raise ValueError(f"expects `a` and `guide` have the same rank") + + return tuple( + product_like(x, g, loc=loc, ip=ip) for x, g in zip(a, target_profile) + ) + + +@overload +def product_each(a: IntTuple, *, loc=None, ip=None) -> IntTuple: ... +@overload +def product_each(a: Shape, *, loc=None, ip=None) -> Shape: ... + + +@dsl_user_op +def product_each(a, *, loc=None, ip=None): + """Compute products for each component of the input. 
+ + Returns a rank(a) tuple `result` such that get(result, mode=[i]) == product(get(a, mode=[i])) + + :param a: The input tuple or shape + :type a: IntTuple or Shape + :param loc: Source location for MLIR, defaults to None + :type loc: optional + :param ip: Insertion point, defaults to None + :type ip: optional + :return: A tuple containing products for each component + :rtype: tuple + :raises TypeError: If input is not an IntTuple or Shape + """ + if is_integer(a): + return a + if isinstance(a, tuple): + if not a: + return 1 + else: + a_val = _pack_int_tuple(a, loc=loc, ip=ip) + res = _cute_ir.tuple_product_each(a_val, loc=loc, ip=ip) + return _unpack_x_tuple(res, loc=loc, ip=ip) + else: + raise TypeError(f"expects IntTuple or Shape, but got {type(a)}") + + +@dsl_user_op +def size( + a: Union[IntTuple, Shape, Layout, ComposedLayout, Tensor], + mode: List[int] = [], + *, + loc=None, + ip=None, +) -> Int: + """Return size of domain of layout or tensor. + + Computes the size (number of elements) in the domain of a layout or tensor. + For layouts, this corresponds to the shape of the coordinate space. + See https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/01_layout.md + for more details on layout domains. + + :param a: The input object whose size to compute + :type a: IntTuple, Shape, Layout, ComposedLayout or Tensor + :param mode: List of mode(s) for size calculation. 
@dsl_user_op
def size(
    a: Union[IntTuple, Shape, Layout, ComposedLayout, Tensor],
    mode: List[int] = [],
    *,
    loc=None,
    ip=None,
) -> Int:
    """Return the size of the domain of a layout or tensor.

    Computes the size (number of elements) in the domain of a layout or tensor.
    For layouts, this corresponds to the shape of the coordinate space.
    See https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/cute/01_layout.md
    for more details on layout domains.

    ``TiledMma`` and ``TiledCopy`` inputs are answered directly from their
    ``size`` attribute; ``mode`` is still validated but otherwise unused for them.

    :param a: The input object whose size to compute
    :type a: IntTuple, Shape, Layout, ComposedLayout or Tensor
    :param mode: List of mode(s) for size calculation. If empty, computes total size, defaults to []
    :type mode: list of int, optional
    :param loc: Source location for MLIR, defaults to None
    :type loc: optional
    :param ip: Insertion point, defaults to None
    :type ip: optional
    :return: Static size of layout or tensor if static, otherwise a Value
    :rtype: int or Value
    :raises ValueError: If mode contains non-integer elements
    """
    if not all(isinstance(m, int) for m in mode):
        raise ValueError(f"expects integer elements in mode, but got {mode}")

    if isinstance(a, (TiledMma, TiledCopy)):
        return a.size

    # Normalize the operand into something the IR op accepts.
    if isinstance(a, Tensor):
        operand = a.value
    elif isinstance(a, (Layout, ComposedLayout)):
        operand = a
    else:
        operand = _pack_int_tuple(a, loc=loc, ip=ip)

    size_val = _cute_ir.size(operand, mode=mode, loc=loc, ip=ip)
    return _unpack_x_tuple(size_val, loc=loc, ip=ip)  # type: ignore


@dsl_user_op
def shape_div(lhs: Shape, rhs: Shape, *, loc=None, ip=None) -> Shape:
    """Perform element-wise division of shapes.

    Both operands are packed as shapes and handed to the ``shape_div`` IR op.

    :param lhs: Left-hand side shape
    :type lhs: Shape
    :param rhs: Right-hand side shape
    :type rhs: Shape
    :param loc: Source location for MLIR, defaults to None
    :type loc: optional
    :param ip: Insertion point, defaults to None
    :type ip: optional
    :return: The result of element-wise division
    :rtype: Shape
    """
    lhs_val = _pack_shape(lhs, loc=loc, ip=ip)
    rhs_val = _pack_shape(rhs, loc=loc, ip=ip)
    quotient = _cute_ir.shape_div(lhs_val, rhs_val, loc=loc, ip=ip)
    return _unpack_x_tuple(quotient, loc=loc, ip=ip)


@dsl_user_op
def ceil_div(input: Shape, tiler: Tiler, *, loc=None, ip=None) -> Shape:
    """
    Compute the ceiling division of a target shape by a tiling specification.

    This function computes the number of tiles required to cover the target domain.
    It is equivalent to the second mode of `zipped_divide(input, tiler)`.

    :param input: A tuple of integers representing the dimensions of the target domain.
    :type input: Shape
    :param tiler: The tiling specification.
    :type tiler: Union[Layout, Shape, Tile]
    :param loc: Optional location information for IR diagnostics.
    :type loc: optional
    :param ip: Optional insertion point for the underlying IR functions.
    :type ip: optional
    :return: A tuple of integers representing the number of tiles required along each dimension,
             i.e. the result of the ceiling division of the input dimensions by the tiler dimensions.
    :rtype: Shape

    Example:

    .. code-block:: python

        import cutlass.cute as cute
        @cute.jit
        def foo():
            input = (10, 6)
            tiler = (3, 4)
            result = cute.ceil_div(input, tiler)
            print(result)  # Outputs: (4, 2)
    """
    packed_input = _pack_shape(input, loc=loc, ip=ip)
    packed_tiler = _pack_tile(tiler, loc=loc, ip=ip)
    tile_counts = _cute_ir.ceil_div(
        input=packed_input, tiler=packed_tiler, loc=loc, ip=ip
    )
    return _unpack_x_tuple(tile_counts, loc=loc, ip=ip)
def round_up(a: IntTuple, b: IntTuple) -> IntTuple:
    """Round each element of ``a`` up to the next multiple of the matching element of ``b``.

    For scalars this is ``((a + b - 1) // b) * b``. For tuples the function
    recurses mode by mode; ``b`` may have a smaller rank than ``a``, in which
    case it is extended with ``append(b, 1, rank(a))`` before pairing the modes.

    :param a: Value(s) to round up
    :type a: IntTuple
    :param b: Granularity for each element of ``a``; must be a tuple when ``a`` is
        a tuple and a scalar when ``a`` is a scalar
    :type b: IntTuple
    :return: ``a`` with every leaf rounded up to a multiple of the matching leaf of ``b``
    :rtype: IntTuple
    :raises ValueError: If ``a`` is an empty tuple, or ``rank(a) < rank(b)``
    :raises TypeError: If exactly one of ``a`` and ``b`` is a tuple

    Example:

    .. code-block:: python

        round_up(10, 4)            # 12
        round_up((10, 6), (4, 4))  # (12, 8)
    """
    if isinstance(a, tuple):
        if not a:
            raise ValueError("inputs cannot be empty")
        if not isinstance(b, tuple):
            raise TypeError(
                f"expects both inputs to be tuple, but got {type(a)} and {type(b)}"
            )
        if rank(a) < rank(b):
            raise ValueError(
                f"expects rank(a) to be greater or equal than rank(b), but got {a}, {b}"
            )
        # Extend `b` to the rank of `a` so that zip() pairs every mode of `a`.
        b = append(b, 1, rank(a))
        return tuple(round_up(x, y) for x, y in zip(a, b))

    # A scalar `a` with a tuple `b` used to fall through to `a + b` and fail
    # with an opaque operand error; report the mismatch explicitly instead.
    if isinstance(b, tuple):
        raise TypeError(
            f"expects both inputs to be scalar, but got {type(a)} and {type(b)}"
        )
    return ((a + b - 1) // b) * b


#
# Layout API (also used by tensors)
#
@dsl_user_op
def make_layout(
    shape: Shape, *, stride: Union[Stride, None] = None, loc=None, ip=None
) -> Layout:
    """Create a CuTe Layout object from shape and optional stride information.

    A Layout in CuTe represents the mapping between logical and physical coordinates of a tensor.
    This function creates a Layout object that defines how tensor elements are arranged in memory.

    :param shape: Shape of the layout defining the size of each mode
    :type shape: Shape
    :param stride: Optional stride values for each mode, defaults to None
    :type stride: Union[Stride, None]
    :param loc: Source location information, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point for IR generation, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: A new Layout object with the specified shape and stride
    :rtype: Layout
    :raises ValueError: If ``stride`` is given and is not congruent with ``shape``

    **Examples:**

    .. code-block:: python

        # Create a 2D compact left-most layout with shape (4,4)
        layout = make_layout((4,4))                     # compact left-most layout

        # Create a left-most layout with custom strides
        layout = make_layout((4,4), stride=(1,4))       # left-most layout with strides (1,4)

        # Create a layout for a 3D tensor
        layout = make_layout((32,16,8))                 # left-most layout

        # Create a layout with custom strides
        layout = make_layout((2,2,2), stride=(4,1,2))   # layout with strides (4,1,2)

    Note:
        - If stride is not provided, a default compact left-most stride is computed based on the shape
        - The resulting layout maps logical coordinates to physical memory locations
        - The layout object can be used for tensor creation and memory access patterns
        - Strides can be used to implement:
            * Row-major vs column-major layouts
            * Padding and alignment
            * Blocked/tiled memory arrangements
            * Interleaved data formats
        - Stride is keyword only argument to improve readability, e.g.
            * make_layout((3,4), (1,4)) can be confusing with make_layout(((3,4), (1,4)))
            * make_layout((3,4), stride=(1,4)) is more readable
    """
    has_stride = stride is not None
    if has_stride and not is_congruent(shape, stride):
        raise ValueError(f"shape and stride must be congruent")

    packed_shape = _pack_shape(shape, loc=loc, ip=ip)
    if has_stride:
        packed_stride = _pack_stride(stride, loc=loc, ip=ip)
        layout_type = _cute_ir.LayoutType.get(packed_shape, packed_stride)
    else:
        # No stride: the layout type is built from the shape alone.
        packed_stride = None
        layout_type = _cute_ir.LayoutType.get(packed_shape)

    return _cute_ir.make_layout(
        layout_type, shape=packed_shape, stride=packed_stride, loc=loc, ip=ip
    )
@dsl_user_op
def make_identity_layout(shape: Shape, *, loc=None, ip=None) -> Layout:
    """Create an identity layout with the given shape.

    An identity layout maps logical coordinates directly to themselves without any transformation.
    This is equivalent to a layout with stride (1@0,1@1,...,1@(N-1)).

    :param shape: The shape of the layout
    :type shape: Shape
    :param loc: Source location information, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point for IR generation, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: A new identity Layout object with the specified shape
    :rtype: Layout
    :raises TypeError: If ``shape`` is not an int tuple

    **Examples:**

    .. code-block:: python

        # Create a 2D identity layout with shape (4,4)
        layout = make_identity_layout((4,4))     # stride=(1@0,1@1)

        # Create a 3D identity layout
        layout = make_identity_layout((32,16,8)) # stride=(1@0,1@1,1@2)

    Note:
        - An identity layout is a special case where each coordinate maps to itself
        - Useful for direct coordinate mapping without any transformation
    """
    if not is_int_tuple(shape):
        raise TypeError(f"expects a shape input, got {type(shape)}")
    return _cute_ir.make_identity_layout(
        _pack_shape(shape, loc=loc, ip=ip), loc=loc, ip=ip
    )
@dsl_user_op
def make_ordered_layout(shape: Shape, order: Shape, *, loc=None, ip=None) -> Layout:
    """Create a layout with a specific ordering of dimensions.

    This function creates a layout where the dimensions are ordered according to the
    specified order parameter, allowing for custom dimension ordering in the layout.

    :param shape: The shape of the layout
    :type shape: Shape
    :param order: The ordering of dimensions
    :type order: Shape
    :param loc: Source location information, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point for IR generation, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: A new Layout object with the specified shape and dimension ordering
    :rtype: Layout

    **Examples:**

    .. code-block:: python

        # Create a row-major layout
        layout = make_ordered_layout((4,4), order=(1,0))

        # Create a column-major layout
        layout = make_ordered_layout((4,4), order=(0,1))           # stride=(1,4)

        # Create a layout with custom dimension ordering for a 3D tensor
        layout = make_ordered_layout((32,16,8), order=(2,0,1))     # stride=(128,1,16)

    Note:
        - The order parameter specifies the ordering of dimensions from fastest-varying to slowest-varying
        - For a 2D tensor, (0,1) creates a column-major layout, while (1,0) creates a row-major layout
        - The length of order must match the rank of the shape
    """
    packed_shape = _pack_shape(shape, loc=loc, ip=ip)
    # `order` is packed as a plain int tuple, not as a shape.
    packed_order = _pack_int_tuple(order, loc=loc, ip=ip)
    return _cute_ir.make_ordered_layout(
        shape=packed_shape, order=packed_order, loc=loc, ip=ip
    )
@dsl_user_op
def make_composed_layout(
    inner, offset: IntTuple, outer: Layout, *, loc=None, ip=None
) -> ComposedLayout:
    """Create a composed layout by composing an inner transformation with an outer layout.

    A composed layout applies a sequence of transformations
    to coordinates. The composition is defined as (inner ∘ offset ∘ outer), where the operations
    are applied from right to left.

    :param inner: The inner transformation (can be a Layout or Swizzle)
    :type inner: Union[Layout, Swizzle]
    :param offset: An integral offset applied between transformations
    :type offset: IntTuple
    :param outer: The outer (right-most) layout that is applied first
    :type outer: Layout
    :param loc: Source location information, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point for IR generation, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: A new ComposedLayout representing the composition
    :rtype: ComposedLayout
    :raises TypeError: If ``outer`` is not a Layout, or if ``inner`` is a Swizzle
        while ``outer.stride`` contains a scaled basis

    **Examples:**

    .. code-block:: python

        # Create a basic layout
        inner = make_layout(...)
        outer = make_layout((4,4), stride=(E(0), E(1)))

        # Create a composed layout with an offset
        composed = make_composed_layout(inner, (2,0), outer)

    Note:
        - The composition applies transformations in the order: outer → offset → inner
        - The stride divisibility condition must be satisfied for valid composition
        - Certain compositions (like Swizzle with scaled basis) are invalid and will raise errors
        - Composed layouts inherit many properties from the outer layout
    """
    if not isinstance(outer, Layout):
        raise TypeError(
            f"expects the outer (or right-most or effectively visible) layout to be an affine layout, but got {outer}"
        )

    # A swizzle cannot be composed on top of scaled-basis strides.
    swizzle_on_scaled_basis = isinstance(inner, Swizzle) and has_scaled_basis(
        outer.stride
    )
    if swizzle_on_scaled_basis:
        raise TypeError(f"invalid composition {inner} o {offset} o {outer}")

    packed_offset = _pack_int_tuple(offset, loc=loc, ip=ip)
    return _cute_ir.make_composed_layout(inner, packed_offset, outer, loc=loc, ip=ip)
@dsl_user_op
def cosize(
    a: Union[Layout, ComposedLayout, Tensor], mode: List[int] = [], *, loc=None, ip=None
):
    """Return size of codomain of layout or tensor. Return static value if type is static.

    :param a: Layout, ComposedLayout, or Tensor object
    :type a: Union[Layout, ComposedLayout, Tensor]
    :param mode: List of mode(s) for cosize calculation
    :type mode: List[int], optional
    :param loc: Location information for diagnostics, defaults to None
    :type loc: optional
    :param ip: Insertion point for IR generation, defaults to None
    :type ip: optional
    :return: Static size of layout or tensor (fast fold) if static, or a dynamic Value
    :rtype: Union[int, Value]
    :raises ValueError: If any element of ``mode`` is not static
    """
    if not all(is_static(m) for m in mode):
        raise ValueError(f"expects static mode, but got {mode}")

    # Tensor wrappers are unwrapped to their underlying IR value first.
    operand = a.value if isinstance(a, _Tensor) else a
    cosize_val = _cute_ir.cosize(operand, mode=mode, loc=loc, ip=ip)
    return _unpack_x_tuple(cosize_val, loc=loc, ip=ip)
@dsl_user_op
def size_in_bytes(
    dtype: Type[Numeric], layout: Union[Layout, ComposedLayout], *, loc=None, ip=None
):
    """Calculate the size in bytes based on its data type and layout.

    The result is ``cosize(layout) * dtype.width // 8``. For a ComposedLayout
    the cosize of its ``outer`` layout is used, and its ``inner`` part must be
    a Swizzle.

    :param dtype: The DSL numeric data type
    :type dtype: Type[Numeric]
    :param layout: The layout of the elements. If None, the function returns 0
    :type layout: Layout or ComposedLayout, optional
    :param loc: Location information for diagnostics, defaults to None
    :type loc: optional
    :param ip: Insertion point for IR generation, defaults to None
    :type ip: optional
    :return: The total size in bytes. Returns 0 if the layout is None
    :rtype: int
    :raises TypeError: If ``dtype`` is not a Numeric type, or if ``layout`` is a
        ComposedLayout whose inner part is not a Swizzle
    """
    if not isinstance(dtype, NumericMeta):
        raise TypeError(f"dtype must be a Numeric, but got {dtype}")

    if layout is None:
        return 0

    if isinstance(layout, ComposedLayout):
        if not isinstance(layout.inner, Swizzle):
            raise TypeError(
                f"invalid composed layout {layout}, inner must be a Swizzle"
            )
        measured = layout.outer
    else:
        measured = layout

    return cosize(measured, loc=loc, ip=ip) * dtype.width // 8


@dsl_user_op
def coalesce(input, *, target_profile: Coord = None, loc=None, ip=None):
    """Coalesce ``input`` through the ``coalesce`` IR op.

    When ``target_profile`` is truthy it is packed as a coordinate and passed to
    the op; otherwise the op is invoked without a profile. Note that a falsy
    profile (``None``, ``0``, ``()``) is therefore treated as "not provided".

    :param input: The object to coalesce (forwarded unchanged to the IR op)
    :param target_profile: Optional profile guiding the coalescing, defaults to None
    :type target_profile: Coord, optional
    :param loc: Source location for MLIR, defaults to None
    :type loc: optional
    :param ip: Insertion point, defaults to None
    :type ip: optional
    :return: The result of the ``coalesce`` IR op
    """
    if not target_profile:
        return _cute_ir.coalesce(input, loc=loc, ip=ip)

    packed_profile = _pack_coord(target_profile, loc=loc, ip=ip)
    return _cute_ir.coalesce(input, target_profile=packed_profile, loc=loc, ip=ip)
@dsl_user_op
def crd2idx(coord: Coord, layout, *, loc=None, ip=None):
    """
    Convert a multi-dimensional coordinate into a value using the specified layout.

    This function computes the inner product of the flattened coordinate and stride:

        index = sum(flatten(coord)[i] * flatten(stride)[i] for i in range(len(coord)))

    A ``layout`` given as a plain int or tuple is first turned into a layout
    with ``make_layout``.

    :param coord: A tuple or list representing the multi-dimensional coordinate
                  (e.g., (i, j) for a 2D layout).
    :type coord: Coord
    :param layout: A layout object that defines the memory storage layout, including shape and stride,
                   used to compute the inner product.
    :type layout: Layout or ComposedLayout
    :param loc: Optional location information for IR diagnostics.
    :type loc: optional
    :param ip: Optional insertion point for the underlying IR functions.
    :type ip: optional
    :returns: The result of applying the layout transformation to the provided coordinate.
    :rtype: Any type that the layout maps to

    Example:

    .. code-block:: python

        import cutlass.cute as cute
        @cute.jit
        def foo():
            L = cute.make_layout((5, 4), stride=(4, 1))
            idx = cute.crd2idx((2, 3), L)
            # Computed as: 2 * 4 + 3 = 11
            print(idx)
        foo()  # Expected output: 11
    """
    packed_coord = _pack_coord(coord, loc=loc, ip=ip)

    # Shapes passed in place of a layout are promoted to a layout first.
    mapping = layout
    if isinstance(mapping, (tuple, int)):
        mapping = make_layout(mapping, loc=loc, ip=ip)

    index = _cute_ir.crd2idx(packed_coord, mapping, loc=loc, ip=ip)
    return _unpack_x_tuple(index, loc=loc, ip=ip)


@dsl_user_op
def recast_layout(new_type_bits, old_type_bits, src_layout, *, loc=None, ip=None):
    """Recast ``src_layout`` from one element width to another.

    Thin wrapper that forwards all arguments to the ``recast_layout`` IR op.

    :param new_type_bits: Bit width of the new element type
    :param old_type_bits: Bit width of the current element type
    :param src_layout: The layout to recast
    :param loc: Source location for MLIR, defaults to None
    :type loc: optional
    :param ip: Insertion point, defaults to None
    :type ip: optional
    :return: The result of the ``recast_layout`` IR op
    """
    recast = _cute_ir.recast_layout(
        new_type_bits, old_type_bits, src_layout, loc=loc, ip=ip
    )
    return recast


@dsl_user_op
def slice_and_offset(coord, src, *, loc=None, ip=None):
    """Slice ``src`` at ``coord`` and also compute the offset of ``coord`` in ``src``.

    :param coord: The coordinate used both for slicing and for the offset
    :param src: The object to slice; passed to ``slice_`` and ``crd2idx``
    :param loc: Source location for MLIR, defaults to None
    :type loc: optional
    :param ip: Insertion point, defaults to None
    :type ip: optional
    :return: The pair ``(slice_(src, coord), crd2idx(coord, src))``
    :rtype: tuple
    """
    sliced = slice_(src, coord, loc=loc, ip=ip)
    base_offset = crd2idx(coord, src, loc=loc, ip=ip)
    return sliced, base_offset
@dsl_user_op
@lru_cache_ir()
def shape(
    input: Union[Shape, Tensor, Layout, Tile], *, mode=None, loc=None, ip=None
) -> Shape:
    """Returns the shape of a tensor, layout or tiler.

    For shapes, this function is identical to get.

    This function extracts the shape information from the input object. For tensors and layouts,
    it returns their internal shape property. For tilers, it unpacks the shape from the tile
    representation.

    :param input: The object to extract shape from
    :type input: Union[Shape, Tensor, Layout, Tile]
    :param mode: Optional mode selector to extract specific dimensions from the shape
    :type mode: Optional[int]
    :param loc: Source location for MLIR operation tracking
    :type loc: Optional[Location]
    :param ip: Insertion point for MLIR operation
    :type ip: Optional[InsertionPoint]
    :return: The shape of the input object, optionally filtered by mode
    :rtype: Shape

    Example:

    .. code-block:: python

        # Get shape of a layout
        l0 = cute.make_layout((2, 3, 4))
        s0 = cute.shape(l0)  # => (2, 3, 4)

        # Get shape of a hierarchical tiler
        l1 = cute.make_layout(1)
        s1 = cute.shape((l0, l1))  # => ((2, 3, 4), 1)

        # Get specific mode from a shape
        s2 = cute.shape(l0, mode=0)  # => 2
    """
    # An int tuple already is a shape: just select the requested mode.
    if is_int_tuple(input):
        return get(input, mode=mode)

    if isinstance(input, (Tensor, Layout)):
        full_shape = input.shape
    else:
        # Anything else is treated as a tile and queried through the IR.
        packed_tile = _pack_tile(input, loc=loc, ip=ip)
        full_shape = _unpack_x_tuple(_cute_ir.get_shape(packed_tile), loc=loc, ip=ip)
    return get(full_shape, mode=mode)


#
# Pointer API
#


@dsl_user_op
def recast_ptr(
    ptr: Pointer,
    swizzle_=None,
    dtype: Optional[Type[Numeric]] = None,
    loc=None,
    ip=None,
) -> Pointer:
    """Recast a pointer to a new element type and/or swizzle.

    The result keeps ``ptr``'s memory space and alignment. When ``dtype`` is
    None the pointer's current value type is kept; when ``swizzle_`` is None
    no swizzle attribute is attached.

    :param ptr: The pointer to recast
    :type ptr: Pointer
    :param swizzle_: Optional swizzle whose ``type.attribute`` is attached to the result
    :param dtype: Optional new element type, defaults to None
    :type dtype: Optional[Type[Numeric]]
    :param loc: Source location for MLIR, defaults to None
    :type loc: optional
    :param ip: Insertion point, defaults to None
    :type ip: optional
    :return: The recast pointer
    :rtype: Pointer
    :raises TypeError: If ``dtype`` is given but is not a Numeric class
    """
    elem_mlir_ty = None
    if dtype is not None:
        if not (isclass(dtype) and issubclass(dtype, Numeric)):
            raise TypeError(f"dtype must be a type of Numeric, but got {dtype}")
        elem_mlir_ty = dtype.mlir_type

    if elem_mlir_ty is None:
        value_type = ptr.type.value_type
    else:
        value_type = elem_mlir_ty

    if swizzle_ is not None:
        swizzle_attr = swizzle_.type.attribute
    else:
        swizzle_attr = None

    result_type = _cute_ir.PtrType.get(
        value_type, ptr.memspace, ptr.alignment, swizzle_attr
    )
    return _cute_ir.recast_iter(result_type, ptr.value, loc=loc, ip=ip)
@dsl_user_op
def make_ptr(
    dtype: Union[Type[Numeric], None],
    value,
    mem_space: AddressSpace = AddressSpace.generic,
    *,
    assumed_align=None,
    loc=None,
    ip=None,
) -> Pointer:
    """Create a typed pointer from an integer address or an LLVM pointer value.

    :param dtype: Element type the pointer refers to. ``None`` is rejected even
        though the annotation allows it.
    :type dtype: Type[Numeric]
    :param value: The address: an integer, or an ``ir.Value`` of LLVM pointer
        type (which is first converted to a 64-bit integer)
    :param mem_space: Address space of the pointer, defaults to ``AddressSpace.generic``
    :type mem_space: AddressSpace
    :param assumed_align: Alignment in bytes to assume for the address; defaults
        to the element size in bytes (at least 1)
    :type assumed_align: int, optional
    :param loc: Source location for MLIR, defaults to None
    :type loc: optional
    :param ip: Insertion point, defaults to None
    :type ip: optional
    :return: The new pointer
    :rtype: Pointer
    :raises TypeError: If ``dtype`` is not a Numeric type, ``mem_space`` is not an
        AddressSpace, or ``value`` is not an integer
    :raises ValueError: If neither the element size nor ``assumed_align`` is a
        multiple of the other
    """
    if dtype is None or not isinstance(dtype, NumericMeta):
        raise TypeError(f"expects dtype to be a type of Numeric, but got {dtype}")

    if not isinstance(mem_space, AddressSpace):
        raise TypeError(f"expects mem_space to be an AddressSpace, but got {mem_space}")

    # LLVM pointers are lowered to a 64-bit integer address first.
    if isinstance(value, ir.Value) and llvm.PointerType.isinstance(value.type):
        value = llvm.ptrtoint(T.i64(), value)

    if not is_integer(value):
        raise TypeError(f"expects integer value, but got {type(value)}")
    # Addresses in the tmem address space are 32-bit, all others 64-bit.
    value = Int32(value) if mem_space == AddressSpace.tmem else Int64(value)

    # Sub-byte element types still count as one byte for alignment purposes.
    bytes_per_elt = max(1, dtype.width // 8)
    if assumed_align is None:
        assumed_align = bytes_per_elt

    if bytes_per_elt % assumed_align != 0 and assumed_align % bytes_per_elt != 0:
        raise ValueError(
            f"{bytes_per_elt=} is not a multiple of {assumed_align=} and vice versa."
        )

    aligned_ty = _cute_ir.ConstrainedIntType.get(assumed_align, type(value).width)
    aligned_intptr = _cute_ir.assume(aligned_ty, value.ir_value(), loc=loc, ip=ip)

    # `dtype` was validated above, so it can never be None here; the former
    # `T.i8()` fallback for a None dtype was unreachable and has been removed.
    ptr_ty = _cute_ir.PtrType.get(dtype.mlir_type, mem_space, assumed_align)
    return _cute_ir.inttoptr(ptr_ty, aligned_intptr, loc=loc, ip=ip)


#
# Tensor API
#
@dsl_user_op
def make_tensor(
    iterator, layout: Union[Shape, Layout, ComposedLayout], *, loc=None, ip=None
) -> Tensor:
    """Creates a tensor by composing an engine (iterator/pointer) with a layout.

    A tensor is defined as T = E ∘ L, where E is an engine (array, pointer, or counting iterator)
    and L is a layout that maps logical coordinates to physical offsets. The tensor
    evaluates coordinates by applying the layout mapping and dereferencing the engine
    at the resulting offset.

    :param iterator: Engine component (pointer, iterator, or counting iterator) that provides
                     data access capabilities
    :type iterator: Union[Pointer, IntTuple]
    :param layout: Layout component that defines the mapping from logical coordinates to
                   physical offsets
    :type layout: Union[Shape, Layout, ComposedLayout]
    :param loc: Source location for MLIR operation tracking, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point for MLIR operation, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: A tensor object representing the composition E ∘ L
    :rtype: Tensor

    :raises TypeError: If iterator type is not supported

    **Examples:**

    .. code-block:: python

        # Create a tensor with row-major layout
        layout = make_layout((64, 128), stride=(128, 1))
        tensor = make_tensor(ptr, layout)

        # Create a tensor with hierarchical layout
        layout = make_layout(((128, 8), (1, 4, 1)), stride=((32, 1), (0, 8, 4096)))
        tensor = make_tensor(smem_ptr, layout)

        # Create a coord tensor
        layout = make_layout(2, stride=16 * E(0))
        tensor = make_tensor(5, layout)

    Notes:
        - The engine (iterator) must support random access operations
        - Common engine types include raw pointers, arrays, and random-access iterators
        - The layout defines both the shape (logical dimensions) and stride (physical mapping)
        - Supports both direct coordinate evaluation T(c) and partial evaluation (slicing)
    """
    if isinstance(layout, ComposedLayout):
        # A composed layout that is really a normal layout collapses to its outer part.
        if layout.type.is_normal_layout:
            layout = layout.outer
    elif not isinstance(layout, Layout):
        # A bare shape is promoted to a layout.
        layout = make_layout(layout, loc=loc, ip=ip)

    if is_integer(iterator) or isinstance(iterator, tuple):
        # Integer / tuple engines produce a coordinate tensor.
        engine = _pack_int_tuple(iterator, loc=loc, ip=ip)
        view_type = _cute_ir.CoordTensorType.get(engine.type, layout.type)
    elif isinstance(iterator, Pointer):
        engine = iterator.value
        view_type = _cute_ir.MemRefType.get(engine.type, layout.type)
    else:
        raise TypeError(f"unsupported iterator type, got {type(iterator)}")

    return _cute_ir.make_view(
        result=view_type, iter=engine, layout=layout, loc=loc, ip=ip
    )
@dsl_user_op
def make_identity_tensor(shape: Shape, *, loc=None, ip=None) -> Tensor:
    """Creates an identity tensor with the given shape.

    An identity tensor maps each coordinate to itself, effectively creating a counting
    sequence within the shape's bounds. This is useful for generating coordinate indices
    or creating reference tensors for layout transformations.

    :param shape: The shape defining the tensor's dimensions. Can be a simple integer
                  sequence or a hierarchical structure ((m,n),(p,q))
    :type shape: Shape
    :param loc: Source location for MLIR operation tracking, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point for MLIR operation, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: A tensor that maps each coordinate to itself
    :rtype: Tensor

    **Examples:**

    .. code-block:: python

        # Create a simple 1D coord tensor
        tensor = make_identity_tensor(6)  # [0,1,2,3,4,5]

        # Create a 2D coord tensor
        tensor = make_identity_tensor((3,2))  # [(0,0),(1,0),(2,0),(0,1),(1,1),(2,1)]

        # Create hierarchical coord tensor
        tensor = make_identity_tensor(((2,1),3))
        # [((0,0),0),((1,0),0),((0,0),1),((1,0),1),((0,0),2),((1,0),2)]

    Notes:
        - The shape parameter follows CuTe's IntTuple concept
        - Coordinates are ordered colexicographically
        - Useful for generating reference coordinates in layout transformations
    """
    shape_val = _pack_shape(shape, loc=loc, ip=ip)
    return _cute_ir.make_identity_tensor(shape_val, loc=loc, ip=ip)


@dsl_user_op
def make_fragment(
    layout_or_shape: Union[Layout, Shape],
    dtype: Type[Numeric],
    *,
    loc=None,
    ip=None,
) -> Tensor:
    """Allocate a tensor in the ``rmem`` address space with the given layout and element type.

    :param layout_or_shape: Layout of the fragment; anything that is not a Layout
        is turned into one with ``make_layout``
    :type layout_or_shape: Union[Layout, Shape]
    :param dtype: Element type of the fragment. Boolean elements are stored as i8.
    :type dtype: Type[Numeric]
    :param loc: Source location for MLIR, defaults to None
    :type loc: optional
    :param ip: Insertion point, defaults to None
    :type ip: optional
    :return: The newly allocated tensor, typed with ``dtype``
    :rtype: Tensor
    :raises TypeError: If ``dtype`` is not a Numeric class
    """
    # Guard with isclass() first: issubclass() itself raises an unrelated
    # TypeError when handed a non-class (e.g. an instance or None). This matches
    # the validation used by recast_ptr/recast_tensor.
    if not isclass(dtype) or not issubclass(dtype, Numeric):
        raise TypeError(f"dtype must be a type of Numeric, but got {dtype}")
    elem_ty = dtype.mlir_type if dtype is not Boolean else T.i8()

    # Alignment for register memory is useless(?), pick-up large enough number
    # to allow .128 (> 16B) load store
    alignment = 32
    if isinstance(layout_or_shape, Layout):
        layout = layout_or_shape
    else:
        layout = make_layout(layout_or_shape, loc=loc, ip=ip)

    ptr_ty = _cute_ir.PtrType.get(elem_ty, AddressSpace.rmem, alignment)
    res_ty = _cute_ir.MemRefType.get(ptr_ty, layout.type)
    tensor = _cute_ir.memref_alloca(res_ty, layout=layout, loc=loc, ip=ip)
    return _Tensor(tensor.value, dtype)
@overload
def make_fragment_like(
    src: Tensor, dtype: Optional[Type[Numeric]] = None, *, loc=None, ip=None
) -> Tensor: ...


@overload
def make_fragment_like(src: Layout, *, loc=None, ip=None) -> Layout: ...


@overload
def make_fragment_like(src: ComposedLayout, *, loc=None, ip=None) -> ComposedLayout: ...


@dsl_user_op
def make_fragment_like(src, dtype=None, *, loc=None, ip=None):
    """Create tensor with a compact layout in the same shape as the source on stack.

    This function either creates a fragment tensor with compact layout in
    same shape as the source layout or a new layout with the same shape as the source.
    The strides of the new layout follow the order induced by the source's strides, with a
    special handling of the 0th mode: it is always stride-1 and generated in column-major order
    (LayoutLeft).

    :param src: The source layout or tensor whose shape will be matched
    :type src: Union[Layout, ComposedLayout, Tensor]
    :param dtype: The element type for the fragment tensor, defaults to None
    :type dtype: Type[Numeric], optional
    :param loc: Source location for MLIR operations, defaults to None
    :type loc: Location, optional
    :param ip: Insertion point for MLIR operations, defaults to None
    :type ip: InsertionPoint, optional

    :return: A new layout or fragment tensor with matching shape
    :rtype: Union[Layout, Tensor]
    :raises ValueError: If ``src`` is a coordinate tensor and ``dtype`` is None
    :raises TypeError: If ``src`` is not a Layout, ComposedLayout or Tensor

    **Examples:**

    Creating a rmem tensor from a tensor:

    .. code-block:: python

        smem_tensor = cute.make_tensor(smem_ptr, layout)
        frag_tensor = cute.make_fragment_like(smem_tensor, cutlass.Float32)
        # frag_tensor will be a register-backed tensor with the same shape

    Creating a fragment with a different element type:

    .. code-block:: python

        tensor = cute.make_tensor(gmem_ptr, layout)
        bool_frag = cute.make_fragment_like(tensor, cutlass.Boolean)
        # bool_frag will be a register-backed tensor with Boolean elements

    **Notes**

    - When used with a Tensor, if a type is provided, it will create a new
      fragment tensor with that element type.
    - When used with a layout and a type, the new layout is converted into a
      fragment tensor via ``make_fragment``.
    - For layouts with ScaledBasis strides, the function creates a fragment
      from the shape only.
    - This function is commonly used in GEMM and other tensor operations to
      create register storage for intermediate results.

    """
    if isinstance(src, (Layout, ComposedLayout)):
        # Create base fragment layout
        if isinstance(src, Layout) and has_scaled_basis(src.stride):
            # For scaled basis strides, create fragment from shape only.
            # loc/ip are forwarded so the helper layout carries the same
            # source location and insertion point as the rest of this op.
            new_layout = _cute_ir.make_fragment_like(
                make_layout(src.shape, loc=loc, ip=ip), loc=loc, ip=ip
            )
        else:
            # Otherwise use full source layout
            new_layout = _cute_ir.make_fragment_like(src, loc=loc, ip=ip)

        if dtype is None:
            return new_layout
        # call make_fragment to convert layout to tensor
        return make_fragment(new_layout, dtype, loc=loc, ip=ip)

    if isinstance(src, Tensor):
        if isinstance(src.type, _cute_ir.CoordTensorType):
            # Coordinate tensors have no element type to inherit.
            if dtype is None:
                raise ValueError(
                    "dtype must be provided when src is a coordinate tensor"
                )

            new_layout = _cute_ir.make_fragment_like(
                make_layout(src.shape, loc=loc, ip=ip), loc=loc, ip=ip
            )
            return make_fragment(new_layout, dtype, loc=loc, ip=ip)

        dtype = src.element_type if dtype is None else dtype
        # Boolean elements are stored as i8.
        ty = dtype.mlir_type if dtype is not Boolean else T.i8()
        new_tensor = _cute_ir.make_fragment_like(
            src.value, elem_type=ty, loc=loc, ip=ip
        )
        return _Tensor(new_tensor.value, dtype)

    raise TypeError(
        f"src must be a Layout or ComposedLayout or tensor, got {type(src)}"
    )
@dsl_user_op
def recast_tensor(
    src: Tensor, dtype: Type[Numeric], swizzle_=None, *, loc=None, ip=None
):
    """Reinterpret the elements of ``src`` as ``dtype``.

    The iterator is recast with ``recast_ptr`` and the layout with
    ``recast_layout``, using the bit widths of the old and new element types.
    Boolean counts as 8 bits wide on either side.

    :param src: The tensor to recast
    :type src: Tensor
    :param dtype: The new element type
    :type dtype: Type[Numeric]
    :param swizzle_: Accepted for interface compatibility; not used by this function
    :param loc: Source location for MLIR, defaults to None
    :type loc: optional
    :param ip: Insertion point, defaults to None
    :type ip: optional
    :return: A tensor over the recast iterator and layout
    :rtype: Tensor
    :raises TypeError: If ``dtype`` is not a Numeric class
    """
    if not isclass(dtype) or not issubclass(dtype, Numeric):
        raise TypeError(f"dtype must be a type of Numeric, but got {dtype}")

    dst_width = 8 if dtype is Boolean else dtype.width
    src_width = 8 if src.element_type is Boolean else src.element_type.width

    # NOTE(review): `swizzle_` is not forwarded to recast_ptr here — confirm
    # whether that is intentional before relying on it.
    src_iter = recast_ptr(src.iterator, dtype=dtype, loc=loc, ip=ip)
    src_layout = recast_layout(dst_width, src_width, src.layout, loc=loc, ip=ip)
    return make_tensor(src_iter, src_layout, loc=loc, ip=ip)


@dsl_user_op
def domain_offset(coord: Coord, tensor: Tensor, *, loc=None, ip=None) -> Tensor:
    """Return a tensor with the same layout whose iterator is advanced to ``coord``.

    The offset is ``crd2idx(coord, tensor.layout)``. Pointer iterators are
    advanced with ``+``; integer or tuple iterators are advanced with the
    ``add_offset`` IR op.

    :param coord: The coordinate to offset the tensor's domain by
    :type coord: Coord
    :param tensor: The tensor to offset
    :type tensor: Tensor
    :param loc: Source location for MLIR, defaults to None
    :type loc: optional
    :param ip: Insertion point, defaults to None
    :type ip: optional
    :return: A new tensor over the offset iterator and the original layout
    :rtype: Tensor
    :raises ValueError: If the tensor's iterator is neither a Pointer, an integer nor a tuple
    """
    offset = crd2idx(coord, tensor.layout, loc=loc, ip=ip)

    # loc/ip are forwarded to every helper below so that all ops emitted on
    # behalf of this call share the caller's source location and insertion
    # point (previously only crd2idx received them).
    if isinstance(tensor.iterator, Pointer):
        return make_tensor(tensor.iterator + offset, tensor.layout, loc=loc, ip=ip)

    if is_integer(tensor.iterator) or isinstance(tensor.iterator, tuple):
        new_iter = _cute_ir.add_offset(
            _pack_int_tuple(tensor.iterator, loc=loc, ip=ip),
            _pack_int_tuple(offset, loc=loc, ip=ip),
            loc=loc,
            ip=ip,
        )
        return make_tensor(
            _unpack_x_tuple(new_iter, loc=loc, ip=ip), tensor.layout, loc=loc, ip=ip
        )

    raise ValueError(f"unsupported tensor for domain_offset, got {tensor}")


#
# Layout algebra
#
@overload
def composition(
    lhs: Layout, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None
) -> Layout: ...


@overload
def composition(
    lhs: Tensor, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None
) -> Tensor: ...


@dsl_user_op
def composition(lhs, rhs: Union[Layout, Shape, Tile], *, loc=None, ip=None):
    """
    Compose two layout representations using the CuTe layout algebra.

    Compose a left-hand layout (or tensor) with a right-hand operand into a new layout R, such that
    for every coordinate c in the domain of the right-hand operand, the composed layout satisfies:

        R(c) = A(B(c))

    where A is the left-hand operand provided as ``lhs`` and B is the right-hand operand provided as
    ``rhs``. In this formulation, B defines the coordinate domain while A applies its transformation to
    B's output, and the resulting layout R inherits the stride and shape adjustments from A.

    Satisfies:
        cute.shape(cute.composition(lhs, rhs)) is compatible with cute.shape(rhs)

    :param lhs: The left-hand operand representing the transformation to be applied.
    :type lhs: Layout or Tensor
    :param rhs: The right-hand operand defining the coordinate domain. If provided as an int or tuple,
                it will be converted to a tile layout.
    :type rhs: Layout, Shape, or Tile, or int or tuple
    :param loc: Optional location information for IR diagnostics.
    :type loc: optional
    :param ip: Optional insertion point for the underlying IR functions.
    :type ip: optional
    :returns: A new composed layout R, such that for all coordinates c in the domain of ``rhs``,
              R(c) = lhs(rhs(c)).
    :rtype: Layout or Tensor

    Example:

    .. code-block:: python

        import cutlass.cute as cute
        @cute.jit
        def foo():
            # Create a layout that maps (i,j) to i*4 + j
            L1 = cute.make_layout((2, 3), stride=(4, 1))
            # Create a layout that maps (i,j) to i*3 + j
            L2 = cute.make_layout((3, 4), stride=(3, 1))
            # Compose L1 and L2
            L3 = cute.composition(L1, L2)
            # L3 now maps coordinates through L2 then L1
    """
    # Plain ints/tuples are packed into a tile; layouts are passed through as-is.
    needs_packing = not isinstance(rhs, Layout) and isinstance(rhs, (int, tuple))
    if needs_packing:
        rhs_operand = _pack_tile(rhs, loc=loc, ip=ip)
    else:
        rhs_operand = rhs

    # Tensor wrappers are unwrapped to their underlying IR value.
    lhs_operand = lhs.value if isinstance(lhs, _Tensor) else lhs
    return _cute_ir.composition(lhs_operand, rhs_operand, loc=loc, ip=ip)
@dsl_user_op
def complement(
    input: Layout, cotarget: Union[Layout, Shape], *, loc=None, ip=None
) -> Layout:
    """
    Compute the complement layout of the input layout with respect to the cotarget.

    The complement of a layout A with respect to cotarget n is a layout A* such that
    for every k in Z_n and c in the domain of A, there exists a unique c* in the domain
    of A* where k = A(c) + A*(c*).

    This operation is useful for creating layouts that partition a space in complementary ways,
    such as row and column layouts that together cover a matrix.

    :param input: The layout to compute the complement of
    :type input: Layout
    :param cotarget: The target layout or shape that defines the codomain
    :type cotarget: Union[Layout, Shape]
    :param loc: Optional location information for IR diagnostics
    :type loc: optional
    :param ip: Optional insertion point for the underlying IR functions
    :type ip: optional
    :returns: The complement layout
    :rtype: Layout

    Example:

    .. code-block:: python

        import cutlass.cute as cute
        @cute.jit
        def foo():
            # Create a right-major layout for a 4x4 matrix
            row_layout = cute.make_layout((4, 4), stride=(4, 1))
            # Create a left-major layout that complements the row layout
            col_layout = cute.complement(row_layout, 16)
            # The two layouts are complementary under 16
    """
    # A Layout cotarget is used directly; anything else is packed as a shape.
    if isinstance(cotarget, Layout):
        target = cotarget
    else:
        target = _pack_shape(cotarget, loc=loc, ip=ip)
    return _cute_ir.complement(input, cotarget=target, loc=loc, ip=ip)


@dsl_user_op
def right_inverse(input: Layout, *, loc=None, ip=None) -> Layout:
    """Compute the right inverse of a layout via the ``right_inverse`` IR op.

    :param input: The layout to invert
    :type input: Layout
    :param loc: Source location for MLIR, defaults to None
    :type loc: optional
    :param ip: Insertion point, defaults to None
    :type ip: optional
    :return: The right-inverse layout
    :rtype: Layout
    :raises TypeError: If ``input`` is not a Layout
    """
    if isinstance(input, Layout):
        return _cute_ir.right_inverse(input=input, loc=loc, ip=ip)
    raise TypeError(f"expects input of type Layout, but got {type(input)}")


@dsl_user_op
def left_inverse(input: Layout, *, loc=None, ip=None) -> Layout:
    """Compute the left inverse of a layout via the ``left_inverse`` IR op.

    :param input: The layout to invert
    :type input: Layout
    :param loc: Source location for MLIR, defaults to None
    :type loc: optional
    :param ip: Insertion point, defaults to None
    :type ip: optional
    :return: The left-inverse layout
    :rtype: Layout
    :raises TypeError: If ``input`` is not a Layout
    """
    if isinstance(input, Layout):
        return _cute_ir.left_inverse(input=input, loc=loc, ip=ip)
    raise TypeError(f"expects input of type Layout, but got {type(input)}")
@dsl_user_op
def logical_product(block, tiler: Layout, *, loc=None, ip=None):
    """
    Forward ``block`` and ``tiler`` to the IR-level ``logical_product`` operation.

    :param block: Value passed as the ``input`` operand (a Layout or ComposedLayout per the overloads)
    :param tiler: Layout passed as the ``tiler`` operand
    :type tiler: Layout
    :param loc: Optional location information for IR diagnostics
    :param ip: Optional insertion point for the underlying IR function
    :returns: Whatever the IR-level operation returns
    """
    product = _cute_ir.logical_product(input=block, tiler=tiler, loc=loc, ip=ip)
    return product


@overload
def zipped_product(
    block: Layout, tiler: Layout, *, loc=None, ip=None
) -> Layout: ...


@overload
def zipped_product(
    block: ComposedLayout, tiler: Layout, *, loc=None, ip=None
) -> ComposedLayout: ...


@dsl_user_op
def zipped_product(block, tiler: Layout, *, loc=None, ip=None):
    """
    Forward ``block`` and ``tiler`` to the IR-level ``zipped_product`` operation.

    :param block: Value passed as the ``input`` operand (a Layout or ComposedLayout per the overloads)
    :param tiler: Layout passed as the ``tiler`` operand
    :type tiler: Layout
    :param loc: Optional location information for IR diagnostics
    :param ip: Optional insertion point for the underlying IR function
    :returns: Whatever the IR-level operation returns
    """
    product = _cute_ir.zipped_product(input=block, tiler=tiler, loc=loc, ip=ip)
    return product


@overload
def tiled_product(
    block: Layout, tiler: Layout, *, loc=None, ip=None
) -> Layout: ...


@overload
def tiled_product(
    block: ComposedLayout, tiler: Layout, *, loc=None, ip=None
) -> ComposedLayout: ...


@dsl_user_op
def tiled_product(block, tiler: Layout, *, loc=None, ip=None):
    """
    Forward ``block`` and ``tiler`` to the IR-level ``tiled_product`` operation.

    :param block: Value passed as the ``input`` operand (a Layout or ComposedLayout per the overloads)
    :param tiler: Layout passed as the ``tiler`` operand
    :type tiler: Layout
    :param loc: Optional location information for IR diagnostics
    :param ip: Optional insertion point for the underlying IR function
    :returns: Whatever the IR-level operation returns
    """
    product = _cute_ir.tiled_product(input=block, tiler=tiler, loc=loc, ip=ip)
    return product


@overload
def flat_product(
    block: Layout, tiler: Layout, *, loc=None, ip=None
) -> Layout: ...


@overload
def flat_product(
    block: ComposedLayout, tiler: Layout, *, loc=None, ip=None
) -> ComposedLayout: ...


@dsl_user_op
def flat_product(block, tiler: Layout, *, loc=None, ip=None):
    """
    Forward ``block`` and ``tiler`` to the IR-level ``flat_product`` operation.

    :param block: Value passed as the ``input`` operand (a Layout or ComposedLayout per the overloads)
    :param tiler: Layout passed as the ``tiler`` operand
    :type tiler: Layout
    :param loc: Optional location information for IR diagnostics
    :param ip: Optional insertion point for the underlying IR function
    :returns: Whatever the IR-level operation returns
    """
    product = _cute_ir.flat_product(input=block, tiler=tiler, loc=loc, ip=ip)
    return product


@overload
def raked_product(
    block: Layout, tiler: Layout, *, loc=None, ip=None
) -> Layout: ...


@overload
def raked_product(
    block: ComposedLayout, tiler: Layout, *, loc=None, ip=None
) -> ComposedLayout: ...


@dsl_user_op
def raked_product(block, tiler: Layout, *, loc=None, ip=None):
    """
    Forward ``block`` and ``tiler`` to the IR-level ``raked_product`` operation.

    :param block: Value passed as the ``input`` operand (a Layout or ComposedLayout per the overloads)
    :param tiler: Layout passed as the ``tiler`` operand
    :type tiler: Layout
    :param loc: Optional location information for IR diagnostics
    :param ip: Optional insertion point for the underlying IR function
    :returns: Whatever the IR-level operation returns
    """
    product = _cute_ir.raked_product(input=block, tiler=tiler, loc=loc, ip=ip)
    return product


@overload
def blocked_product(
    block: Layout, tiler: Layout, *, loc=None, ip=None
) -> Layout: ...


@overload
def blocked_product(
    block: ComposedLayout, tiler: Layout, *, loc=None, ip=None
) -> ComposedLayout: ...
@dsl_user_op
def blocked_product(block, tiler: Layout, *, loc=None, ip=None):
    """
    Forward ``block`` and ``tiler`` to the IR-level ``blocked_product`` operation.

    :param block: Value passed as the ``input`` operand (a Layout or ComposedLayout per the overloads)
    :param tiler: Layout passed as the ``tiler`` operand
    :type tiler: Layout
    :param loc: Optional location information for IR diagnostics
    :param ip: Optional insertion point for the underlying IR function
    :returns: Whatever the IR-level operation returns
    """
    product = _cute_ir.blocked_product(input=block, tiler=tiler, loc=loc, ip=ip)
    return product


@overload
def logical_divide(
    target: Layout, tiler: Tiler, *, loc=None, ip=None
) -> Layout: ...


@overload
def logical_divide(
    target: Tensor, tiler: Tiler, *, loc=None, ip=None
) -> Tensor: ...


@dsl_user_op
def logical_divide(target, tiler: Tiler, *, loc=None, ip=None):
    """
    Apply the IR-level ``logical_divide`` operation to a layout or a tensor.

    A ``_Tensor`` target is unwrapped to its IR value before the call. A tiler given as a Python
    tuple is packed into a tile first. If the IR result is a ``_Tensor`` it is re-wrapped with the
    element type recorded from the input (``None`` when the input was not a ``_Tensor``).

    :param target: Layout or tensor to divide
    :param tiler: Tiler, or a tuple that is packed into a tile
    :type tiler: Tiler
    :param loc: Optional location information for IR diagnostics
    :param ip: Optional insertion point for the underlying IR function
    :returns: The divided layout or tensor
    """
    target_is_tensor = isinstance(target, _Tensor)
    elem_type = target.element_type if target_is_tensor else None
    ir_input = target.value if target_is_tensor else target
    ir_tiler = _pack_tile(tiler, loc=loc, ip=ip) if isinstance(tiler, tuple) else tiler

    divided = _cute_ir.logical_divide(input=ir_input, tiler=ir_tiler, loc=loc, ip=ip)
    if isinstance(divided, _Tensor):
        return _Tensor(divided, dtype=elem_type)
    return divided


@overload
def zipped_divide(
    target: Layout, tiler: Tiler, *, loc=None, ip=None
) -> Layout: ...


@overload
def zipped_divide(
    target: Tensor, tiler: Tiler, *, loc=None, ip=None
) -> Tensor: ...


@dsl_user_op
def zipped_divide(target, tiler: Tiler, *, loc=None, ip=None):
    """
    Apply the IR-level ``zipped_divide`` operation to a layout or a tensor.

    A ``_Tensor`` target is unwrapped to its IR value before the call. A tiler given as a Python
    tuple is packed into a tile first. If the IR result is a ``_Tensor`` it is re-wrapped with the
    element type recorded from the input (``None`` when the input was not a ``_Tensor``).

    :param target: Layout or tensor to divide
    :param tiler: Tiler, or a tuple that is packed into a tile
    :type tiler: Tiler
    :param loc: Optional location information for IR diagnostics
    :param ip: Optional insertion point for the underlying IR function
    :returns: The divided layout or tensor
    """
    target_is_tensor = isinstance(target, _Tensor)
    elem_type = target.element_type if target_is_tensor else None
    ir_input = target.value if target_is_tensor else target
    ir_tiler = _pack_tile(tiler, loc=loc, ip=ip) if isinstance(tiler, tuple) else tiler

    divided = _cute_ir.zipped_divide(input=ir_input, tiler=ir_tiler, loc=loc, ip=ip)
    if isinstance(divided, _Tensor):
        return _Tensor(divided, dtype=elem_type)
    return divided


@overload
def tiled_divide(
    target: Layout, tiler: Tiler, *, loc=None, ip=None
) -> Layout: ...


@overload
def tiled_divide(
    target: Tensor, tiler: Tiler, *, loc=None, ip=None
) -> Tensor: ...
@dsl_user_op
def tiled_divide(target, tiler: Tiler, *, loc=None, ip=None):
    """
    Apply the IR-level ``tiled_divide`` operation to a layout or a tensor.

    A ``_Tensor`` target is unwrapped to its IR value before the call. A tiler given as a Python
    tuple is packed into a tile first. If the IR result is a ``_Tensor`` it is re-wrapped with the
    element type recorded from the input (``None`` when the input was not a ``_Tensor``).

    :param target: Layout or tensor to divide
    :param tiler: Tiler, or a tuple that is packed into a tile
    :type tiler: Tiler
    :param loc: Optional location information for IR diagnostics
    :param ip: Optional insertion point for the underlying IR function
    :returns: The divided layout or tensor
    """
    res_type = None
    if isinstance(target, _Tensor):
        res_type = target.element_type
        target = target.value
    if isinstance(tiler, tuple):
        tiler = _pack_tile(tiler, loc=loc, ip=ip)
    res = _cute_ir.tiled_divide(input=target, tiler=tiler, loc=loc, ip=ip)
    return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res


@overload
def flat_divide(target: Layout, tiler: Tiler, *, loc=None, ip=None) -> Layout: ...
@overload
def flat_divide(target: Tensor, tiler: Tiler, *, loc=None, ip=None) -> Tensor: ...


@dsl_user_op
def flat_divide(target, tiler: Tiler, *, loc=None, ip=None):
    """
    Apply the IR-level ``flat_divide`` operation to a layout or a tensor.

    A ``_Tensor`` target is unwrapped to its IR value before the call. A tiler given as a Python
    tuple is packed into a tile first. If the IR result is a ``_Tensor`` it is re-wrapped with the
    element type recorded from the input (``None`` when the input was not a ``_Tensor``).

    :param target: Layout or tensor to divide
    :param tiler: Tiler, or a tuple that is packed into a tile
    :type tiler: Tiler
    :param loc: Optional location information for IR diagnostics
    :param ip: Optional insertion point for the underlying IR function
    :returns: The divided layout or tensor
    """
    res_type = None
    if isinstance(target, _Tensor):
        res_type = target.element_type
        target = target.value
    if isinstance(tiler, tuple):
        tiler = _pack_tile(tiler, loc=loc, ip=ip)
    res = _cute_ir.flat_divide(input=target, tiler=tiler, loc=loc, ip=ip)
    return _Tensor(res, dtype=res_type) if isinstance(res, _Tensor) else res


#
# Higher-level utilities
#


@dsl_user_op
def max_common_layout(
    a: Union[Layout, Tensor], b: Union[Layout, Tensor], *, loc=None, ip=None
) -> Layout:
    """
    Return the leading static stride-1 component shared by two layouts.

    The function coalesces ``composition(a, right_inverse(b))`` and inspects its mode 0. When that
    mode has a static (Python ``int``) shape and a static stride equal to 1, the result is
    ``composition(right_inverse(b), mode0)``; otherwise a layout of shape 1 and stride 0 is
    returned.

    :param a: First layout, or a tensor whose ``layout`` is used
    :type a: Union[Layout, Tensor]
    :param b: Second layout, or a tensor whose ``layout`` is used
    :type b: Union[Layout, Tensor]
    :param loc: Optional location information for IR diagnostics
    :param ip: Optional insertion point for the underlying IR functions
    :returns: The common layout, or ``1:0`` when no static stride-1 mode is found
    :rtype: Layout
    """
    a_layout = a.layout if isinstance(a, _Tensor) else a
    b_layout = b.layout if isinstance(b, _Tensor) else b

    inv_b = right_inverse(b_layout, loc=loc, ip=ip)
    common = coalesce(composition(a_layout, inv_b, loc=loc, ip=ip), loc=loc, ip=ip)

    # some_ir_value == 1 generates a new IR Value which evaluates to True!
    # Hence both values are checked to be Python ints before comparing.
    s = get(common.shape, mode=[0], loc=loc, ip=ip)
    d = get(common.stride, mode=[0], loc=loc, ip=ip)
    # Keep only the static identity component of the common layout
    if isinstance(s, int) and isinstance(d, int) and d == 1:
        # Truncate to the size of the contiguous vector (static stride-1 mode)
        return composition(inv_b, get(common, mode=[0], loc=loc, ip=ip), loc=loc, ip=ip)
    else:
        return make_layout(1, stride=0, loc=loc, ip=ip)


@dsl_user_op
def max_common_vector(
    a: Union[Layout, Tensor], b: Union[Layout, Tensor], *, loc=None, ip=None
) -> int:
    """
    Return the number of leading elements shared contiguously by two layouts.

    The function coalesces ``composition(a, right_inverse(b))`` and inspects its mode 0. When that
    mode has a static (Python ``int``) shape and a static stride equal to 1, its shape is returned;
    otherwise 1 is returned.

    :param a: First layout, or a tensor whose ``layout`` is used
    :type a: Union[Layout, Tensor]
    :param b: Second layout, or a tensor whose ``layout`` is used
    :type b: Union[Layout, Tensor]
    :param loc: Optional location information for IR diagnostics
    :param ip: Optional insertion point for the underlying IR functions
    :returns: The static size of the common stride-1 mode, or 1
    :rtype: int
    """
    a_layout = a.layout if isinstance(a, _Tensor) else a
    b_layout = b.layout if isinstance(b, _Tensor) else b

    inv_b = right_inverse(b_layout, loc=loc, ip=ip)
    common = coalesce(composition(a_layout, inv_b, loc=loc, ip=ip), loc=loc, ip=ip)

    # Comparing a dynamic IR value with ``== 1`` builds a new IR value that is truthy, so the
    # stride must be a Python int before it is compared (same guard as ``max_common_layout``).
    s = get(common.shape, mode=[0], loc=loc, ip=ip)
    d = get(common.stride, mode=[0], loc=loc, ip=ip)
    # Keep only the static identity component of the common layout
    if isinstance(s, int) and isinstance(d, int) and d == 1:
        # Truncate to the size of the contiguous vector (static stride-1 mode)
        return s
    else:
        return 1


@dsl_user_op
def tile_to_shape(
    atom: Union[Layout, ComposedLayout],
    trg_shape: Shape,
    order: Shape,
    *,
    loc=None,
    ip=None,
) -> Union[Layout, ComposedLayout]:
    """
    Forward to the IR-level ``tile_to_shape`` operation after packing its arguments.

    ``trg_shape`` is passed through ``shape()`` and packed as a shape; ``order`` is packed as an
    integer tuple.

    :param atom: Layout or composed layout passed to the IR operation
    :type atom: Union[Layout, ComposedLayout]
    :param trg_shape: Target shape
    :type trg_shape: Shape
    :param order: Order argument forwarded to the IR operation
    :type order: Shape
    :param loc: Optional location information for IR diagnostics
    :param ip: Optional insertion point for the underlying IR functions
    :returns: The result of the IR-level operation
    :rtype: Union[Layout, ComposedLayout]
    """
    trg_shape = _pack_shape(shape(trg_shape), loc=loc, ip=ip)
    order = _pack_int_tuple(order, loc=loc, ip=ip)
    return _cute_ir.tile_to_shape(atom, trg_shape, order, loc=loc, ip=ip)


@dsl_user_op
def local_partition(
    target: Tensor,
    tiler: Union[Layout, Shape],
    index: Union[int, Numeric],
    proj: XTuple = 1,
    *,
    loc=None,
    ip=None,
) -> Tensor:
    """
    Forward to the IR-level ``local_partition`` operation.

    The tiler handed to the IR operation is ``dice(tiler, proj)``.

    :param target: Tensor whose IR value is partitioned
    :type target: Tensor
    :param tiler: Tiler that is diced with ``proj``
    :type tiler: Union[Layout, Shape]
    :param index: Index as a Python int, a Numeric, or an arithmetic IR value
    :type index: Union[int, Numeric]
    :param proj: Projection passed to ``dice``, defaults to 1
    :type proj: XTuple
    :param loc: Optional location information for IR diagnostics
    :param ip: Optional insertion point for the underlying IR functions
    :returns: The result of the IR-level operation
    :rtype: Tensor
    :raises NotImplementedError: If the index IR type is wider than 32 bits
    """
    if isinstance(index, cutlass_arith.ArithValue):
        index_val = index
    else:
        if isinstance(index, int):
            # A plain Python int has no ``ir_value``; wrap it so the annotated ``int`` case works.
            index = Int32(index)
        index_val = index.ir_value()
    if index_val.type.width > 32:
        raise NotImplementedError(
            f"Index value should be 32-bit or smaller integer type, but got {index_val.type}"
        )
    return _cute_ir.local_partition(
        input=target.value, tiler=dice(tiler, proj), index=index_val, loc=loc, ip=ip
    )


@dsl_user_op
def local_tile(
    input: Tensor,
    tiler: Union[Layout, Shape],
    coord: Coord,
    proj: XTuple = None,
    *,
    loc=None,
    ip=None,
) -> Tensor:
    """
    Forward to the IR-level ``local_tile`` operation after packing its arguments.

    ``tiler`` is packed as a shape and ``coord`` as a coordinate. When ``proj`` is given it is
    packed as a coordinate and the ``attribute`` of the packed value's type is passed on.

    :param input: Tensor whose IR value is tiled
    :type input: Tensor
    :param tiler: Tiler packed as a shape
    :type tiler: Union[Layout, Shape]
    :param coord: Coordinate forwarded to the IR operation
    :type coord: Coord
    :param proj: Optional projection; must be a tuple when provided
    :type proj: XTuple
    :param loc: Optional location information for IR diagnostics
    :param ip: Optional insertion point for the underlying IR functions
    :returns: The result of the IR-level operation
    :rtype: Tensor
    :raises TypeError: If ``proj`` is provided and is not a tuple
    """
    tiler_val = _pack_shape(tiler, loc=loc, ip=ip)
    coord_val = _pack_coord(coord, loc=loc, ip=ip)
    if proj is not None:
        if not isinstance(proj, tuple):
            raise TypeError(f"Expects tuple for proj, but got {type(proj)}")
        proj_val = _pack_coord(proj, loc=loc, ip=ip)
        proj = proj_val.type.attribute

    return _cute_ir.local_tile(
        input=input.value,
        tile=tiler_val,
        static_tile=None,
        coord=coord_val,
        static_coord=None,
        proj=proj,
        loc=loc,
        ip=ip,
    )


@dsl_user_op
def make_layout_image_mask(
    lay: Layout, coord: Coord, mode: int, *, loc=None, ip=None
) -> Int16:
    """
    Makes a 16-bit integer mask of the image of a layout sliced at a given mode
    and accounting for the offset given by the input coordinate for the other modes.

    :param lay: The layout to slice; must be static
    :type lay: Layout
    :param coord: Coordinate with the same rank as ``lay``
    :type coord: Coord
    :param mode: Index of the mode to keep, in ``[0, rank(lay))``
    :type mode: int
    :param loc: Optional location information for IR diagnostics
    :param ip: Optional insertion point for the underlying IR functions
    :returns: The mask
    :rtype: Int16
    :raises ValueError: If the layout is not static, the ranks differ, ``mode`` is out of range,
        or the cosize of the layout exceeds 16
    """
    if not is_static(lay):
        raise ValueError(
            f"make_layout_image_mask requires the layout to be static, but got {pretty_str(lay)}"
        )
    r = rank(lay)
    if rank(coord) != r:
        raise ValueError(
            f"the rank of the coordinate must be equal to the one of the layout, but got {pretty_str(coord)}"
        )
    # The valid range is half-open: ``mode == r`` does not designate any mode of the layout
    if mode < 0 or mode >= r:
        raise ValueError(f"expects `mode` to be in [0,rank(lay)), but got {mode}")
    # Given that we require the layout to be static, we can check that the mask fits in 16 bits
    # This might be too conservative but safe
    if cosize(lay) > 16:
        raise ValueError("the mask may not fit into a 16-bit integer")

    # Replace the mode to keep with _ in the coordinate
    slicer = tuple(None if idx == mode else x for idx, x in enumerate(coord))
    # Slice the layout with the slicer above and keep track of the offset
    sliced_lay, offset = slice_and_offset(slicer, lay, loc=loc, ip=ip)
    # Given that we replace only one mode with _, the rank of the slice should be 1
    assert rank(sliced_lay) == 1

    # Create the mask of the image
    mcast_mask = Int16(0)
    for i in range(size(sliced_lay)):
        mcast_mask = mcast_mask | (1 << sliced_lay(i))
    mcast_mask <<= offset
    return Int16(mcast_mask)


####################################################################################################
#
# Atom
#
####################################################################################################


class Op(ABC):
    """
    Operation abstract base class.
    """

    pass


class MmaOp(Op):
    """
    MMA Operation abstract base class.
    """

    @abstractmethod
    def _make_trait(self, *, loc=None, ip=None, **kwargs):
        """Build the Trait for this operation; implemented by subclasses."""
        pass


class CopyOp(Op):
    """
    Copy Operation abstract base class.
    """

    @abstractmethod
    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ):
        """Build the Trait for this operation; implemented by subclasses."""
        pass


class Trait(ABC):
    """
    Trait abstract base class.

    Traits are internal-only classes used by Atoms that wrap the underlying IR Value. The Python
    user should only interact with Ops and Atoms.
    """

    def __init__(self, value: ir.Value) -> None:
        self.value = value

    def __extract_mlir_values__(self):
        """Return the wrapped IR value as a single-element list."""
        return [self.value]

    def __new_from_mlir_values__(self, values):
        """Create a new instance of the same class wrapping ``values[0]``."""
        return self.__class__(values[0])

    def set(self, field, value, *, loc=None, ip=None) -> None:
        """
        Set a runtime field; the base implementation always raises.

        :raises NotImplementedError: Always, unless overridden by a subclass
        """
        raise NotImplementedError(
            "set not implemented, the requesting Atom has likely no runtime state"
        )

    def unpack(self, *, loc=None, ip=None, **kwargs) -> ir.Value:
        """Return the wrapped IR value; extra arguments are ignored by the base class."""
        return self.value


class Atom(ABC):
    """
    Atom base class.

    An Atom is the composition of

    - a MMA or Copy Operation;
    - an internal MMA or Copy Trait.

    An Operation is a pure Python class that is used to model a specific MMA or Copy instruction.
    The Trait wraps the underlying IR Value and provides access to the metadata of the instruction
    encoded using CuTe Layouts. When the Trait can be constructed straightforwardly from an
    Operation, the ``make_mma_atom`` or ``make_copy_atom`` API should be used. There are cases where
    constructing the metadata is not trivial and requires more information, for example to determine
    the number of bytes copied per TMA instruction ("the TMA vector length"). In such cases,
    dedicated helper functions are provided with an appropriate API such that the Atom is
    constructed internally in an optimal fashion for the user.
    """

    def __init__(self, op: Op, trait: Trait) -> None:
        self._op = op
        self._trait = trait

    def __extract_mlir_values__(self):
        """Return the IR values extracted from the wrapped Trait."""
        return extract_mlir_values(self._trait)

    def __new_from_mlir_values__(self, values):
        """Create a new instance of the same class with the same Op and a rebuilt Trait."""
        return self.__class__(self.op, new_from_mlir_values(self._trait, values))

    @property
    def op(self) -> Op:
        """The Operation this Atom was built with."""
        return self._op

    @property
    def type(self):
        """The type of the IR value wrapped by the Trait."""
        return self._trait.value.type

    @dsl_user_op
    def set(self, modifier, value, *, loc=None, ip=None) -> None:
        """
        Sets runtime fields of the Atom.

        Some Atoms have runtime state, for example a tcgen05 MMA Atom


        .. code-block:: python

            tiled_mma = cute.make_tiled_mma(some_tcgen05_mma_op)
            tiled_mma.set(cute.nvgpu.tcgen05.Field.ACCUMULATE, True)

        The ``set`` method provides a way to the user to modify such runtime state. Modifiable
        fields are provided by arch-specific enumerations, for example ``tcgen05.Field``. The Atom
        instance internally validates the field as well as the value provided by the user to set
        the field to.
        """
        self._trait.set(modifier, value, loc=loc, ip=ip)

    def _unpack(self, *, loc=None, ip=None, **kwargs) -> ir.Value:
        """Delegate to the Trait's ``unpack`` and return its IR value."""
        return self._trait.unpack(loc=loc, ip=ip, **kwargs)


####################################################################################################
#
# MMA Atoms, TiledMma, and ThrMma
#
####################################################################################################


class MmaAtom(Atom):
    """
    The MMA Atom class.
    """

    def __str__(self) -> str:
        res = "MMA Atom\n"
        res += " ThrID: " + pretty_str(self.thr_id) + "\n"
        res += " Shape MNK: " + pretty_str(self.shape_mnk) + "\n"
        res += " TV Layout A: " + pretty_str(self.tv_layout_A) + "\n"
        res += " TV Layout B: " + pretty_str(self.tv_layout_B) + "\n"
        res += " TV Layout C: " + pretty_str(self.tv_layout_C)
        return res

    #
    # Properties
    #

    @property
    def thr_id(self) -> Layout:
        return _cute_ir.static(self._trait.value.type.thr_id)

    @property
    def shape_mnk(self) -> Shape:
        return _unpack_x_tuple(self._trait.value.type.shape_mnk)

    @property
    def tv_layout_A(self) -> Layout:
        return _cute_ir.static(self._trait.value.type.layout_a_tv)

    @property
    def tv_layout_B(self) -> Layout:
        return _cute_ir.static(self._trait.value.type.layout_b_tv)

    @property
    def tv_layout_C(self) -> Layout:
        return _cute_ir.static(self._trait.value.type.layout_c_tv)

    #
    # make_fragment
    #

    @dsl_user_op
    def make_fragment_A(self, input, *, loc=None, ip=None):
        """
        Make the fragment for operand A through the IR-level ``mma_make_fragment`` operation.

        A ``_Tensor`` input is verified by the Op (when the Atom has one) and unwrapped to its IR
        value; a tuple input is packed as a shape.
        """
        # input could be memref/shape/layout for tmem based fragment
        if isinstance(input, _Tensor):
            if self.op is not None:
                self.op._verify_fragment_A(input, loc=loc, ip=ip)
            input = input.value
        if isinstance(input, tuple):
            input = _pack_shape(input, loc=loc, ip=ip)
        return _cute_ir.mma_make_fragment(
            _cute_ir.MmaOperand.A,
            self._trait.value,
            input,
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def make_fragment_B(self, input, *, loc=None, ip=None):
        """
        Make the fragment for operand B through the IR-level ``mma_make_fragment`` operation.

        A ``_Tensor`` input is verified by the Op (when the Atom has one) and unwrapped to its IR
        value. Unlike operands A and C, a tuple input is not packed here.
        """
        if isinstance(input, _Tensor):
            if self.op is not None:
                self.op._verify_fragment_B(input, loc=loc, ip=ip)
            input = input.value
        return _cute_ir.mma_make_fragment(
            _cute_ir.MmaOperand.B,
            self._trait.value,
            input,
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def make_fragment_C(self, input, *, loc=None, ip=None):
        """
        Make the fragment for operand C through the IR-level ``mma_make_fragment`` operation.

        A ``_Tensor`` input is unwrapped to its IR value; a tuple input is packed as a shape.
        """
        # input could be memref/shape/layout for tmem based fragment
        if isinstance(input, _Tensor):
            input = input.value
        if isinstance(input, tuple):
            input = _pack_shape(input, loc=loc, ip=ip)
        return _cute_ir.mma_make_fragment(
            _cute_ir.MmaOperand.C,
            self._trait.value,
            input,
            loc=loc,
            ip=ip,
        )


class TiledMma(MmaAtom):
    """
    The tiled MMA class.
    """

    def __str__(self) -> str:
        res = "Tiled MMA\n"
        res += " Thr Layout VMNK: " + pretty_str(self.thr_layout_vmnk) + "\n"
        res += " Permutation MNK: " + pretty_str(self.permutation_mnk) + "\n"
        res += "MMA Atom\n"
        res += " ThrID: " + pretty_str(self.thr_id) + "\n"
        res += " Shape MNK: " + pretty_str(self.shape_mnk) + "\n"
        res += " TV Layout A: " + pretty_str(self.tv_layout_A) + "\n"
        res += " TV Layout B: " + pretty_str(self.tv_layout_B) + "\n"
        res += " TV Layout C: " + pretty_str(self.tv_layout_C)
        return res

    #
    # Properties
    #

    @property
    def tv_layout_A_tiled(self) -> Layout:
        return _cute_ir.static(self._trait.value.type.layout_a_tv_tiled)

    @property
    def tv_layout_B_tiled(self) -> Layout:
        return _cute_ir.static(self._trait.value.type.layout_b_tv_tiled)

    @property
    def tv_layout_C_tiled(self) -> Layout:
        return _cute_ir.static(self._trait.value.type.layout_c_tv_tiled)

    @property
    def permutation_mnk(self) -> Tile:
        return _unpack_x_tuple(self._trait.value.type.permutation_mnk)

    @property
    def thr_layout_vmnk(self) -> Layout:
        return _cute_ir.static(self._trait.value.type.thr_layout_vmnk)

    @property
    def size(self) -> int:
        return self._trait.value.type.size

    #
    # Tiler
    #

    def get_tile_size(self, mode_idx: int) -> Shape:
        """
        Return the tile size along one of the M, N, K modes.

        When the permutation for the mode is ``None``, the size is the Atom shape along that mode
        multiplied by the size of mode ``mode_idx + 1`` of the VMNK thread layout; otherwise it is
        the size of the permutation tile.

        :param mode_idx: Mode index; must be 0, 1, or 2
        :type mode_idx: int
        """
        assert (mode_idx >= 0) and (mode_idx < 3)
        perm_tile = self.permutation_mnk[mode_idx]
        if perm_tile is None:
            thr_layout_vmnk = self.thr_layout_vmnk
            atom_shape_mnk = self.shape_mnk
            return size(atom_shape_mnk, mode=[mode_idx]) * size(
                thr_layout_vmnk, mode=[mode_idx + 1]
            )
        else:
            return size(perm_tile)

    #
    # get_slice
    #

    def get_slice(self, thr_idx: Union[int, Int32]) -> "ThrMma":
        """Return a ``ThrMma`` sharing this object's Op and Trait for the given thread index."""
        return ThrMma(self.op, self._trait, thr_idx)

    #
    # partition_shape
    #

    def _partition_shape(self, operand_id, shape, *, loc=None, ip=None):
        """Pack ``shape``, call the IR-level ``tiled_mma_partition_shape``, and unpack the result."""
        shape = _pack_shape(shape, loc=loc, ip=ip)
        return _unpack_x_tuple(
            _cute_ir.tiled_mma_partition_shape(
                operand_id, self._trait.value, shape, loc=loc, ip=ip
            ),
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def partition_shape_A(self, shape_mk, *, loc=None, ip=None):
        return self._partition_shape(_cute_ir.MmaOperand.A, shape_mk, loc=loc, ip=ip)

    @dsl_user_op
    def partition_shape_B(self, shape_nk, *, loc=None, ip=None):
        return self._partition_shape(_cute_ir.MmaOperand.B, shape_nk, loc=loc, ip=ip)

    @dsl_user_op
    def partition_shape_C(self, shape_mn, *, loc=None, ip=None):
        return self._partition_shape(_cute_ir.MmaOperand.C, shape_mn, loc=loc, ip=ip)

    #
    # _thrfrg
    #

    @overload
    def _thrfrg(self, operand_id, input: Layout, *, loc=None, ip=None) -> Layout: ...

    @overload
    def _thrfrg(self, operand_id, input: Tensor, *, loc=None, ip=None) -> Tensor: ...

    def _thrfrg(self, operand_id, input, *, loc=None, ip=None) -> Union[Tensor, Layout]:
        """
        Apply the ``thrfrg`` of the tiled MMA type to a layout or to the layout of a tensor.

        For a tensor, a new tensor is made from the same iterator and the transformed layout.

        :raises ValueError: If a layout input is not static, or the input is neither a layout nor
            a tensor
        """
        if isinstance(input, Tensor):
            return make_tensor(
                input.iterator,
                self._thrfrg(operand_id, input.layout, loc=loc, ip=ip),
            )
        elif isinstance(input, Layout):
            if not is_static(input.type):
                raise ValueError(f"Expects a static layout but got {input.type}")
            return _cute_ir.static(
                self._trait.value.type.thrfrg(operand_id, input), loc=loc, ip=ip
            )

        raise ValueError(
            f"Expects a layout or a tensor as input but got {type(input)=}"
        )

    def _thrfrg_A(
        self, input: Union[Layout, Tensor], *, loc=None, ip=None
    ) -> Union[Layout, Tensor]:
        return self._thrfrg(_cute_ir.MmaOperand.A, input, loc=loc, ip=ip)

    def _thrfrg_B(
        self, input: Union[Layout, Tensor], *, loc=None, ip=None
    ) -> Union[Layout, Tensor]:
        return self._thrfrg(_cute_ir.MmaOperand.B, input, loc=loc, ip=ip)

    def _thrfrg_C(
        self, input: Union[Layout, Tensor], *, loc=None, ip=None
    ) -> Union[Layout, Tensor]:
        return self._thrfrg(_cute_ir.MmaOperand.C, input, loc=loc, ip=ip)


class ThrMma(TiledMma):
    """
    The thread MMA class for modeling a thread-slice of a tiled MMA.
    """

    def __init__(self, op: Op, trait: Trait, thr_idx: Union[int, Int32]) -> None:
        super().__init__(op, trait)
        self._thr_idx = thr_idx

    def __new_from_mlir_values__(self, values):
        """Create a new instance with the same Op and thread index and a rebuilt Trait."""
        return self.__class__(
            self.op, new_from_mlir_values(self._trait, values), self.thr_idx
        )

    @property
    def thr_idx(self):
        """The thread index this slice was created with."""
        return self._thr_idx

    @dsl_user_op
    def partition_A(self, input_mk: Tensor, *, loc=None, ip=None) -> Tensor:
        """Partition ``input_mk`` for operand A at this slice's thread index."""
        thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip)
        return _cute_ir.tiled_mma_partition(
            _cute_ir.MmaOperand.A,
            self._trait.value,
            input_mk.value,
            thr_idx,
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def partition_B(self, input_nk: Tensor, *, loc=None, ip=None) -> Tensor:
        """Partition ``input_nk`` for operand B at this slice's thread index."""
        thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip)
        return _cute_ir.tiled_mma_partition(
            _cute_ir.MmaOperand.B,
            self._trait.value,
            input_nk.value,
            thr_idx,
            loc=loc,
            ip=ip,
        )

    @dsl_user_op
    def partition_C(self, input_mn: Tensor, *, loc=None, ip=None) -> Tensor:
        """Partition ``input_mn`` for operand C at this slice's thread index."""
        thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip)
        return _cute_ir.tiled_mma_partition(
            _cute_ir.MmaOperand.C,
            self._trait.value,
            input_mn.value,
            thr_idx,
            loc=loc,
            ip=ip,
        )


@dsl_user_op
def make_mma_atom(op: MmaOp, *, loc=None, ip=None, **kwargs) -> MmaAtom:
    """
    Makes an MMA Atom from an MMA Operation.

    This function creates an MMA Atom from a given MMA Operation. Arbitrary kw arguments can be
    provided for Op-specific additional parameters. They are not used as of today.

    :param op: The MMA Operation to construct an Atom for
    :type op: MmaOp
    :return: The MMA Atom
    :rtype: MmaAtom
    """
    trait = op._make_trait(loc=loc, ip=ip, **kwargs)
    return MmaAtom(op, trait)


@dsl_user_op
def make_tiled_mma(
    op_or_atom: Union[Op, MmaAtom],
    atom_layout_mnk=(1, 1, 1),
    permutation_mnk=None,
    *,
    loc=None,
    ip=None,
    **kwargs,
) -> TiledMma:
    """
    Makes a tiled MMA from an MMA Operation or an MMA Atom.

    :param op_or_atom: The MMA Operation or Atom
    :type op_or_atom: Union[Op, MmaAtom]
    :param atom_layout_mnk: A Layout describing the tiling of Atom across threads; a tuple is
        converted with ``make_layout``
    :type atom_layout_mnk: Layout
    :param permutation_mnk: A permutation Tiler describing the tiling of Atom across values including any permutation of such tiling
    :type permutation_mnk: Tiler
    :return: The resulting tiled MMA
    :rtype: TiledMma
    :raises TypeError: If ``op_or_atom`` is neither an Op nor an MmaAtom
    :raises ValueError: If the atom layout is not rank-3
    """
    if isinstance(op_or_atom, Op):
        op = op_or_atom
        atom = make_mma_atom(op_or_atom, loc=loc, ip=ip, **kwargs)
    elif isinstance(op_or_atom, MmaAtom):
        op = op_or_atom.op
        atom = op_or_atom
    else:
        raise TypeError(
            f"expected an MMA Op or Atom, but got an instance of {type(op_or_atom)}"
        )
    if isinstance(atom_layout_mnk, tuple):
        atom_layout_mnk = make_layout(atom_layout_mnk, loc=loc, ip=ip)
    if rank(atom_layout_mnk) != 3:
        raise ValueError(f"expects rank-3 MNK atom layout, but got {atom_layout_mnk}")
    permutation_mnk_ty = None
    if permutation_mnk is not None:
        permutation_mnk_ty = _pack_tile(permutation_mnk, loc=loc, ip=ip).type
    ty = _cute_nvgpu_ir.TiledMmaType.get(
        atom._trait.value.type,
        atom_layout_mnk.type,
        permutation_mnk_ty,
    )
    val = _cute_ir.make_tiled_mma(ty, atom._trait.value, loc=loc, ip=ip)
    # Instead of modifying atom which might have been provided by the user, create a brand new
    # trait instance and replace the Atom ir.Value with the tiled one
    trait = new_from_mlir_values(atom._trait, [val])
    return TiledMma(op, trait)
+#################################################################################################### +# +# Copy Atoms, TiledCopy, and ThrCopy +# +#################################################################################################### + + +class CopyAtom(Atom): + """ + The Copy Atom class. + """ + + def __str__(self) -> str: + res = "Copy Atom\n" + res += " ThrID: " + str(self.thr_id) + "\n" + res += " TV Layout Src: " + str(self.layout_src_tv) + "\n" + res += " TV Layout Dst: " + str(self.layout_dst_tv) + "\n" + res += " Value type: " + str(self._trait.value.type.value_type) + return res + + # + # Properties + # + + @property + def value_type(self) -> Type[Numeric]: + return Numeric.from_mlir_type(self._trait.value.type.value_type) + + @property + def thr_id(self) -> Layout: + return _cute_ir.static(self._trait.value.type.thr_id) + + @property + def layout_src_tv(self) -> Layout: + return _cute_ir.static(self._trait.value.type.layout_src_tv) + + @property + def layout_dst_tv(self) -> Layout: + return _cute_ir.static(self._trait.value.type.layout_dst_tv) + + +class TiledCopy(CopyAtom): + """ + The tiled Copy class. 
+ """ + + def __str__(self) -> str: + res = "Tiled Copy\n" + res += " Tiler MN: " + pretty_str(self.tiler_mn) + "\n" + res += " TV Layout tiled: " + str(self.layout_tv_tiled) + "\n" + res += "Copy Atom\n" + res += " ThrID: " + str(self.thr_id) + "\n" + res += " TV Layout Src: " + str(self.layout_src_tv) + "\n" + res += " TV Layout Dst: " + str(self.layout_dst_tv) + "\n" + res += " Value type: " + str(self._trait.value.type.value_type) + return res + + # + # Properties + # + + @property + def layout_tv_tiled(self) -> Layout: + return _cute_ir.static(self._trait.value.type.layout_tv_tiled) + + @property + def tiler_mn(self) -> Tile: + return _unpack_x_tuple(self._trait.value.type.tiler_mn) + + @property + def layout_src_tv_tiled(self) -> Layout: + return _cute_ir.static(self._trait.value.type.layout_src_tv_tiled) + + @property + def layout_dst_tv_tiled(self) -> Layout: + return _cute_ir.static(self._trait.value.type.layout_dst_tv_tiled) + + @property + def size(self) -> int: + return self._trait.value.type.size + + # + # get_slice and retile + # + + def get_slice(self, thr_idx: Union[int, Int32]) -> "ThrCopy": + return ThrCopy(self.op, self._trait, thr_idx) + + @dsl_user_op + def retile(self, src, *, loc=None, ip=None): + return _cute_ir.tiled_copy_retile( + tiled_copy=self._trait.value, input=src.value, loc=loc, ip=ip + ) + + +class ThrCopy(TiledCopy): + """ + The thread Copy class for modeling a thread-slice of a tiled Copy. 
+ """ + + def __init__(self, op: Op, trait: Trait, thr_idx: Union[int, Int32]) -> None: + super().__init__(op, trait) + self._thr_idx = thr_idx + + def __new_from_mlir_values__(self, values): + return self.__class__( + self.op, new_from_mlir_values(self._trait, values), self.thr_idx + ) + + @property + def thr_idx(self): + return self._thr_idx + + @dsl_user_op + def partition_S(self, src: Tensor, *, loc=None, ip=None) -> Tensor: + thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) + return _cute_ir.tiled_copy_partition_S( + self._trait.value, src.value, thr_idx, loc=loc, ip=ip + ) + + @dsl_user_op + def partition_D(self, dst: Tensor, *, loc=None, ip=None) -> Tensor: + thr_idx = _pack_coord(self.thr_idx, loc=loc, ip=ip) + return _cute_ir.tiled_copy_partition_D( + self._trait.value, dst.value, thr_idx, loc=loc, ip=ip + ) + + +@dsl_user_op +def make_copy_atom( + op: CopyOp, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs +) -> CopyAtom: + """ + Makes a Copy Atom from a Copy Operation. + + This function creates a Copy Atom from a given Copy Operation. Arbitrary kw arguments can be + provided for Op-specific additional parameters. + + Example: + + .. code-block:: python + + op = cute.nvgpu.CopyUniversalOp() + atom = cute.make_copy_atom(op, tensor_dtype, num_bits_per_copy=64) + + :param op: The Copy Operation to construct an Atom for + :type op: CopyOp + :param copy_internal_type: An internal data type used to construct the source/destination layouts in unit of tensor elements + :type copy_internal_type: Type[Numeric] + :return: The Copy Atom + :rtype: CopyAtom + """ + trait = op._make_trait(copy_internal_type, loc=loc, ip=ip, **kwargs) + return CopyAtom(op, trait) + + +@dsl_user_op +def make_layout_tv( + thr_layout: Layout, val_layout: Layout, *, loc=None, ip=None +) -> Tuple[Shape, Layout]: + """Create a thread-value layout for partitioning data tensors. 
+ + This function creates a thread-value layout that maps between ``(thread_idx, value_idx)`` + coordinates and logical ``(M,N)`` coordinates. The thread layout must be compact to ensure + proper partitioning. + + This implements the thread-value partitioning pattern shown in + Figure TVLayout, where data is partitioned across threads and values within each thread. + + :param thr_layout: Layout mapping from ``(TileM,TileN)`` coordinates to thread IDs (must be compact) + :type thr_layout: Layout + :param val_layout: Layout mapping from ``(ValueM,ValueN)`` coordinates to value IDs within each thread + :type val_layout: Layout + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tuple containing ``tiler_mn`` and ``layout_tv`` + :rtype: Tuple[Shape, Layout] + + where: + * ``tiler_mn`` is tiler and ``shape(tiler_mn)`` is compatible with ``shape(zipped_divide(x, tiler_mn))[0]`` + * ``layout_tv``: Thread-value layout mapping (thread_idx, value_idx) -> (M,N) + + **Example:** + + .. 
code-block:: python + + tiler_mn, layout_tv = cute.make_layout_tv( + cute.make_layout((4, 8), stride=(8, 1)), cute.make_layout(2, stride=1) + ) + + Above code creates a TV layout that maps between thread/value coordinates + and the logical coordinates in a 8x8 matrix with: + + * thread block layout ``(4,8):(8,1)`` + * 2 elements per thread + """ + + if not isinstance(thr_layout, Layout): + raise TypeError(f"expected a Layout for thr_layout, but got {type(thr_layout)}") + if not isinstance(val_layout, Layout): + raise TypeError(f"expected a Layout for val_layout, but got {type(val_layout)}") + + # Take the raked_products to compute the Layout_MN + # (M,N) -> (thr_idx, val_idx) + layout_mn = raked_product(thr_layout, val_layout, loc=loc, ip=ip) + thr_size = size(thr_layout, loc=loc, ip=ip) + val_size = size(val_layout, loc=loc, ip=ip) + tmp = make_layout((thr_size, val_size), loc=loc, ip=ip) + # (thr_idx, val_idx) -> (M,N) + layout_tv = composition( + right_inverse(layout_mn, loc=loc, ip=ip), tmp, loc=loc, ip=ip + ) + + tiler_mn = product_each(layout_mn.shape, loc=loc, ip=ip) + + return (tiler_mn, layout_tv) + + +def _make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): + if type(tiler_mn) is tuple: + tiler_mn = _pack_tile(tiler_mn, loc=loc, ip=ip) + + assert isinstance(tiler_mn, ir.Value) and _cute_ir.TileType.isinstance( + tiler_mn.type + ), f"tiler_mn must be a Tile, but got {type(tiler_mn)}" + assert is_static(layout_tv.type) and is_static( + tiler_mn.type + ), "layout tv and tiler mn must be static" + tiled_copy_ty = _cute_nvgpu_ir.TiledCopyType.get( + atom.type, layout_tv.type, tiler_mn.type + ) + + val = _cute_ir.make_tiled_copy(tiled_copy_ty, atom._trait.value, loc=loc, ip=ip) + # Instead of modifying atom which might have been provided by the user, create a brand new + # trait instance and replace the Atom ir.Value with the tiled one + trait = new_from_mlir_values(atom._trait, [val]) + return TiledCopy(atom.op, trait) + + +def 
make_tiled_copy(atom, layout_tv, tiler_mn, *, loc=None, ip=None): + """Create a tiled type given a TV partitioner and tiler. + + :param atom: Copy atom, e.g. smit_copy and simt_async_copy, tma_load, etc. + :type atom: CopyAtom + :param layout_tv: Thread-value layout + :type layout_tv: Layout + :param tiler_mn: Tile size + :type tiler_mn: Tiler + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) + + +@dsl_user_op +def make_tiled_copy_tv( + atom: CopyAtom, thr_layout: Layout, val_layout: Layout, *, loc=None, ip=None +) -> TiledCopy: + """Create a tiled copy given separate thread and value layouts. + + A TV partitioner is inferred based on the input layouts. The input thread layout + must be compact. + + :param atom: Copy atom + :type atom: CopyAtom + :param thr_layout: Layout mapping from ``(TileM,TileN)`` coordinates to thread IDs (must be compact) + :type thr_layout: Layout + :param val_layout: Layout mapping from ``(ValueM,ValueN)`` coordinates to value IDs + :type val_layout: Layout + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + tiler_mn, layout_tv = make_layout_tv(thr_layout, val_layout, loc=loc, ip=ip) + tiler_mn = _pack_tile(product_each(tiler_mn, loc=loc, ip=ip), loc=loc, ip=ip) + return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) + + +@dsl_user_op +def make_tiled_copy_A(atom, tiled_mma, *, loc=None, ip=None): + """Create a tiled copy out of the copy_atom that matches the A-Layout of tiled_mma. 
+ + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_mma: Tiled MMA + :type tiled_mma: TiledMma + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + return _make_tiled_copy( + atom, + tiled_mma.tv_layout_A_tiled, + (tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(2)), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_tiled_copy_B(atom, tiled_mma, *, loc=None, ip=None): + """Create a tiled copy out of the copy_atom that matches the B-Layout of tiled_mma. + + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_mma: Tiled MMA + :type tiled_mma: TiledMma + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + return _make_tiled_copy( + atom, + tiled_mma.tv_layout_B_tiled, + (tiled_mma.get_tile_size(1), tiled_mma.get_tile_size(2)), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_tiled_copy_C(atom, tiled_mma, *, loc=None, ip=None): + """Create a tiled copy out of the copy_atom that matches the C-Layout of tiled_mma. 
+ + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_mma: Tiled MMA + :type tiled_mma: TiledMma + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + return _make_tiled_copy( + atom, + tiled_mma.tv_layout_C_tiled, + (tiled_mma.get_tile_size(0), tiled_mma.get_tile_size(1)), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_tiled_copy_S(atom, tiled_copy, *, loc=None, ip=None): + """Create a tiled copy out of the copy_atom that matches the Src-Layout of tiled_copy. + + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_copy: Tiled copy + :type tiled_copy: TiledCopy + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + return _make_tiled_copy( + atom, tiled_copy.layout_src_tv_tiled, tiled_copy.tiler_mn, loc=loc, ip=ip + ) + + +@dsl_user_op +def make_tiled_copy_D(atom, tiled_copy, *, loc=None, ip=None): + """Create a tiled copy out of the copy_atom that matches the Dst-Layout of tiled_copy. 
+ + :param atom: Copy atom + :type atom: CopyAtom + :param tiled_copy: Tiled copy + :type tiled_copy: TiledCopy + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for the partitioner + :rtype: TiledCopy + """ + + return _make_tiled_copy( + atom, tiled_copy.layout_dst_tv_tiled, tiled_copy.tiler_mn, loc=loc, ip=ip + ) + + +@dsl_user_op +def make_tiled_copy_C_atom(atom: CopyAtom, mma: TiledMma, *, loc=None, ip=None): + """Create the smallest tiled copy that can retile LayoutC_TV for use with pipelined epilogues with subtiled stores. + + :param atom: Copy atom + :type atom: CopyAtom + :param mma: Tiled MMA + :type mma: TiledMma + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + + :return: A tiled copy for partitioner + :rtype: TiledCopy + + :raises ValueError: If the number value of CopyAtom's source layout is greater than the size of TiledMma's LayoutC_TV + """ + # Truncate the V-layout to just the Copy_Atom, keep the V-order + layoutC_tv = mma.tv_layout_C_tiled + val_layout_src = atom.layout_src_tv + num_val_src = size(val_layout_src, mode=[1], loc=loc, ip=ip) + num_val_layoutC_tv = size(layoutC_tv, mode=[1], loc=loc, ip=ip) + if num_val_src > num_val_layoutC_tv: + raise ValueError( + f"The number value of CopyAtom's source layout {num_val_src} " + f"is greater than the size of TiledMma's LayoutC_TV {num_val_layoutC_tv}" + ) + layout_TV = composition( + layoutC_tv, + make_layout( + (size(layoutC_tv, mode=[0], loc=loc, ip=ip), num_val_src), loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + + # Recompute tiler and restride the TV layout for the new tiler + + # Tiler -- Find the active elements in the MMA tensor and generate a tiler to extract them + # Convert to the 
awkward by-mode tiler to preserve the modes of the tiled MMA + mma_tiler = (mma.get_tile_size(0), mma.get_tile_size(1)) + + tiler_0 = filter( + composition( + make_layout(mma_tiler, stride=(1, 0), loc=loc, ip=ip), + layout_TV, + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + tiler_1 = filter( + composition( + make_layout(mma_tiler, stride=(0, 1), loc=loc, ip=ip), + layout_TV, + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + tiler = (tiler_0, tiler_1) + + tile2mma = composition( + make_layout(mma_tiler, loc=loc, ip=ip), tiler, loc=loc, ip=ip + ) + layout_tv = composition( + left_inverse(tile2mma, loc=loc, ip=ip), layout_TV, loc=loc, ip=ip + ) + + tiler_mn = _pack_tile(tiler, loc=loc, ip=ip) + + return _make_tiled_copy(atom, layout_tv, tiler_mn, loc=loc, ip=ip) + + +#################################################################################################### +# +# cute.gemm and cute.copy +# +#################################################################################################### + + +@dsl_user_op +def gemm( + atom: MmaAtom, + d: Tensor, + a: Tensor, + b: Tensor, + c: Tensor, + *, + loc=None, + ip=None, + **kwargs, +) -> None: + """The GEMM algorithm. + + Computes ``D <- A * B + C`` where ``C`` and ``D`` can alias. Note that some MMA Atoms (e.g. + warpgroup-wide or tcgen05 MMAs) require manually setting an "accumulate" boolean field. + + All tensors must be partitioned according to the provided MMA Atom. + + For MMA Atoms that require single-threaded execution, the gemm op automatically handles thread + election internally. Manual thread selection is not required in such cases. 
+ + Following dispatch rules are supported: + + - Dispatch [1]: (V) x (V) => (V) => (V,1,1) x (V,1,1) => (V,1,1) + - Dispatch [2]: (M) x (N) => (M,N) => (1,M,1) x (1,N,1) => (1,M,N) + - Dispatch [3]: (M,K) x (N,K) => (M,N) => (1,M,K) x (1,N,K) => (1,M,N) + - Dispatch [4]: (V,M) x (V,N) => (V,M,N) => (V,M,1) x (V,N,1) => (V,M,N) + - Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) + + :param atom: MMA atom + :type atom: MmaAtom + :param d: Destination tensor + :type d: Tensor + :param a: First source tensor + :type a: Tensor + :param b: Second source tensor + :type b: Tensor + :param c: Third source tensor + :type c: Tensor + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point for MLIR, defaults to None + :type ip: Optional[InsertionPoint], optional + :param kwargs: Additional keyword arguments + :type kwargs: dict + :return: None + :rtype: None + """ + + a_rank = rank(a.shape) + b_rank = rank(b.shape) + c_rank = rank(c.shape) + d_rank = rank(d.shape) + + if a_rank != b_rank: + raise ValueError("`a` and `b` must have the same rank") + + if c_rank != d_rank: + raise ValueError("`c` and `d` must have the same rank") + + if a_rank == 1: + if c_rank > 2: + raise ValueError("`c` must have rank <= 2 when `a` has rank 1") + elif a_rank == 2: + if c_rank not in (2, 3): + raise ValueError("`c` must have rank 2 or 3 when `a` has rank 2") + elif a_rank == 3: + if c_rank != 3: + raise ValueError("`c` must have rank 3 when `a` has rank 3") + + value = atom._unpack(loc=loc, ip=ip, **kwargs) + return _cute_ir.gemm(value, d.value, a.value, b.value, c.value, loc=loc, ip=ip) + + +@dsl_user_op +def basic_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: + """Performs a basic element-wise copy. + + This functions **assumes** the following pre-conditions: + 1. 
`size(src) == size(dst)` + + When the `src` and `dst` shapes are static, the pre-conditions are actually verified and the + element-wise loop is fully unrolled. + + :param src: Source tensor + :type src: Tensor + :param dst: Destination tensor + :type dst: Tensor + :param loc: Source location for MLIR, defaults to None + :type loc: Optional[Location], optional + :param ip: Insertion point, defaults to None + :type ip: Optional[InsertionPoint], optional + """ + + if is_static(src.shape) and is_static(dst.shape): + simt_copy_ty = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( + src.element_type.mlir_type, src.element_type.width + ) + simt_copy = _cute_ir.atom(simt_copy_ty, loc=loc, ip=ip) + return _cute_ir.copy(simt_copy, src.value, dst.value, loc=loc, ip=ip) + + s = size(dst, loc=loc, ip=ip) + # Always generate an scf.for Op when one of the tensors is dynamic + for i in for_generate(0, s): + dst[i] = src[i] + yield_out() + + +@dsl_user_op +def basic_copy_if(pred: Tensor, src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: + """Performs a basic predicated element-wise copy. + + This functions **assumes** the following pre-conditions: + 1. `size(src) == size(dst)` + 2. `size(src) == size(pred)` + + When all shapes are static, the pre-conditions are actually verified and the element-wise loop + is fully unrolled. 
+ + """ + if src.element_type.width != dst.element_type.width: + raise NotImplementedError( + "basic_copy_if currently only supports equal source and destination " + "element type bit width" + ) + + if is_static(src.shape) and is_static(dst.shape) and is_static(pred.shape): + return _basic_copy_if_static(pred, src, dst, loc=loc, ip=ip) + + s = size(dst, loc=loc, ip=ip) + # Always generate an scf.for Op when one of the tensors is dynamic + for i in for_generate(0, s): + if_generate(pred[i], lambda: dst.__setitem__(i, src[i])) + yield_out() + + +# Version of basic_copy_if when src and dst have static shapes +# - verify size(src) == size(dst) == size(prd) +# - fully unroll the loop for now +def _basic_copy_if_static( + pred: Tensor, src: Tensor, dst: Tensor, *, loc=None, ip=None +) -> None: + assert is_static(src.shape) and is_static(dst.shape) and is_static(pred.shape) + if size(src, loc=loc, ip=ip) != size(dst, loc=loc, ip=ip): + raise ValueError( + "basic_copy expects the size of source, destination, and predicate tensors to match" + ) + # Fully unrolled loop in the static case for now + for i in range(size(dst, loc=loc, ip=ip)): + if_generate(pred[i], lambda: dst.__setitem__(i, src[i])) + + +@dsl_user_op +def autovec_copy(src: Tensor, dst: Tensor, *, loc=None, ip=None) -> None: + """ + Auto-vectorizing SIMT copy policy. + + Given a source and destination tensors that are statically shaped, this policy figures out the + largest safe vector width that the copy instruction can take and performs the copy. 
+ """ + if src.element_type.width != dst.element_type.width: + raise NotImplementedError( + "autovec_copy currently only supports equal source and destination " + "element type bit width" + ) + + # We are going to dispatch to copy-with-atom which requires shapes to be static + if not is_static(src.shape) or not is_static(dst.shape): + raise ValueError( + "autovec_copy expects source and destination tensors to be statically shaped" + ) + + vec_layout = max_common_layout(src, dst, loc=loc, ip=ip) + num_common_elements = size(vec_layout, loc=loc, ip=ip) + + # Next we construct an upper-bound on the number bits that can be vectorized by considering + # - the maximum alignment of the layouts + # - the maximum alignment of the pointers + + upper_bound = math.gcd(src.layout.max_alignment, dst.layout.max_alignment) + upper_bound = math.gcd(upper_bound, num_common_elements) + upper_bound *= src.element_type.width + + # For our instructions, the alignment of the pointer is an upper bound to the vector width + # max_alignment, as opposed to alignment, takes into account possible address swizzling + upper_bound = math.gcd(upper_bound, src.iterator.max_alignment * 8) + upper_bound = math.gcd(upper_bound, dst.iterator.max_alignment * 8) + + # Finally, we put a cap at 128b + num_bits_per_copy = math.gcd(upper_bound, 128) + + if (num_common_elements > 1) and (num_bits_per_copy % 8 == 0): + num_common_elements = num_bits_per_copy // src.element_type.width + + # 2 step logical divides ensuring that the divides are valid at every step + vec_src = logical_divide(src, vec_layout, loc=loc, ip=ip) + vec_dst = logical_divide(dst, vec_layout, loc=loc, ip=ip) + tiled_src = logical_divide( + vec_src, make_layout(num_common_elements, loc=loc, ip=ip), loc=loc, ip=ip + ) + tiled_dst = logical_divide( + vec_dst, make_layout(num_common_elements, loc=loc, ip=ip), loc=loc, ip=ip + ) + + # Dispatch to copy with atom + simt_type = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( + 
src.element_type.mlir_type, num_bits_per_copy + ) + simt_copy = _cute_ir.atom(simt_type, loc=loc, ip=ip) + return _cute_ir.copy( + simt_copy, tiled_src.value, tiled_dst.value, loc=loc, ip=ip + ) + + # Failed to vectorize, use a basic copy + basic_copy(src, dst, loc=loc, ip=ip) + + +@dsl_user_op +def copy( + atom: CopyAtom, + src: Tensor, + dst: Tensor, + *, + pred: Optional[Tensor] = None, + loc=None, + ip=None, + **kwargs, +) -> None: + """ + The Copy algorithm. + + The "copy with Atom" expects source and destination tensors to be partitioned according to the + provided Copy Atom. Some Atoms require additional Op-specific kw arguments, for example TMA + copies: + + .. code-block:: python + + cute.copy(tma_atom, src, dst, tma_bar_ptr=mbar_ptr, mcast_mask=mask) + + An additional predication tensor can be provided. If the partitioned tensors have the following + logical profile ``((ATOM_V,ATOM_REST),REST_M,...)``, the predication tensor must have a profile + consistent with ``(ATOM_REST,REST_M,...)``. + + For Copy Atoms that require single-threaded execution, the copy op automatically handles thread + election internally. Manual thread selection is not required in such cases. + """ + if isinstance(src.type, _cute_ir.MemRefType) and isinstance( + dst.type, _cute_ir.MemRefType + ): + if src.element_type.width != dst.element_type.width: + raise TypeError( + "`copy` currently only supports equal source and destination " + "element type bit width" + ) + + value = atom._unpack(loc=loc, ip=ip, **kwargs) + if isinstance(pred, Tensor): + pred = pred.value + return _cute_ir.copy(value, src.value, dst.value, pred=pred, loc=loc, ip=ip) + + +@dsl_user_op +def copy_atom_call( + atom: CopyAtom, + src: Tensor, + dst: Tensor, + *, + pred: Optional[Tensor] = None, + loc=None, + ip=None, + **kwargs, +) -> None: + """ + Execute a single copy atom operation. + + The copy_atom_call operation executes a copy atom with the given operands. 
+ Following src/dst layout of atom are valid: + * ((atom_v)) + * (atom_v) + + Note: The format ((atom_v, rest_v)) is NOT valid for copy_atom_call since it would + require multiple atom operations, which contradicts the definition of a single copy atom call. + + Examples: + + .. code-block:: python + + # Call a copy atom operation + cute.copy_atom_call(copy_atom, src_tensor, dst_tensor) + + An additional predication tensor can be provided. If the partitioned tensors have the following + logical profile ``((ATOM_V,ATOM_REST),REST_M,...)``, the predication tensor must have a profile + consistent with ``(ATOM_REST,REST_M,...)``. + """ + if isinstance(src.type, _cute_ir.MemRefType) and isinstance( + dst.type, _cute_ir.MemRefType + ): + if src.element_type.width != dst.element_type.width: + raise TypeError( + "`copy_atom_call` currently only supports equal source and destination " + "element type bit width" + ) + + value = atom._unpack(loc=loc, ip=ip, **kwargs) + if isinstance(pred, Tensor): + pred = pred.value + return _cute_ir.copy_atom_call( + value, src.value, dst.value, pred=pred, loc=loc, ip=ip + ) + + +def prefetch(atom: CopyAtom, src: Tensor, *, loc=None, ip=None) -> None: + """ + The Prefetch algorithm. + + The "prefetch" expects source tensors to be partitioned according to the provided Copy Atom. + Prefetch is used for loading tensors from global memory to L2. + + Prefetch accepts Copy Atom but not all are allowed. Currently, only support for tma load tensor prefetch. + + .. code-block:: python + + cute.prefetch(tma_atom, src) + + For Copy Atoms that require single-threaded execution, the copy op automatically handles thread + election internally. Manual thread selection is not required in such cases. 
+ """ + dummy_tma_bar_ptr = make_ptr(Int64, 0, AddressSpace.smem, loc=loc, ip=ip) + value = atom._unpack(loc=loc, ip=ip, tma_bar_ptr=dummy_tma_bar_ptr) + return _cute_ir.prefetch(value, src.value, loc=loc, ip=ip) + +#################################################################################################### +# +# TensorSSA class (experimental) +# +#################################################################################################### + + +class ReductionOp(Enum): + ADD = auto() + MUL = auto() + MAX = auto() + MIN = auto() + INC = auto() + DEC = auto() + AND = auto() + OR = auto() + XOR = auto() + + def __str__(self): + return self.name.lower() + + +class TensorSSA(cutlass_arith.ArithValue): + """A class representing thread local data from CuTe Tensor in value semantic and immutable. + + :param value: Flatten vector as ir.Value holding logic data of SSA Tensor + :type value: ir.Value + :param shape: The nested shape in CuTe of the vector + :type shape: Shape + :param dtype: Data type of the tensor elements + :type dtype: Type[Numeric] + + :ivar _shape: The nested shape in CuTe of the vector + :ivar _dtype: Data type of the tensor elements + + :raises ValueError: If shape is not static + """ + + def __init__(self, value, shape: Shape, dtype: Type[Numeric]): + """Initialize a new TensorSSA object. 
+ + :param value: Flatten vector as ir.Value holding logic data of SSA Tensor + :type value: ir.Value + :param shape: The nested shape in CuTe of the vector + :type shape: Shape + :param dtype: Data type of the tensor elements + :type dtype: Type[Numeric] + :raises ValueError: If shape is not static + """ + if not is_static(shape): + raise ValueError("dynamic shape is not supported") + + signed = dtype.signed if issubclass(dtype, Integer) else False + super().__init__(value, signed) + + self._shape = shape + self._dtype = dtype + self._layout = None + + @property + def dtype(self) -> Type[Numeric]: + return self._dtype + + @property + def element_type(self) -> Type[Numeric]: + return self._dtype + + @abstractmethod + def __extract_mlir_values__(self): + return [self] + + @abstractmethod + def __new_from_mlir_values__(self, values): + return TensorSSA(values[0], self.shape, self.dtype) + + def __str__(self): + return f"tensor_value<{self.type} o {self.shape}>" + + @property + def shape(self): + return self._shape + + @overload + def _apply_op(self, op, other: "TensorSSA", flip, *, loc, ip) -> "TensorSSA": ... + + @overload + def _apply_op( + self, op, other: cutlass_arith.ArithValue, flip, *, loc, ip + ) -> "TensorSSA": ... + + @overload + def _apply_op( + self, op, other: Union[int, float, bool], flip, *, loc, ip + ) -> "TensorSSA": ... 
+ + def _apply_op(self, op, other, flip=False, *, loc=None, ip=None): + def get_attr_for_type(ty, value): + if isinstance(ty, ir.IntegerType): + return ir.IntegerAttr.get(ty, value) + elif isinstance(ty, ir.FloatType): + return ir.FloatAttr.get(ty, value) + else: + raise TypeError(f"unsupported type: {ty}") + + # Canonicalize into Numeric + if isinstance(other, (int, float, bool)) or ( + not isinstance(other, TensorSSA) + and isinstance(other, cutlass_arith.ArithValue) + ): + other = as_numeric(other) + + # Promote types + lhs, rhs, res_type = _binary_op_type_promote(self, other) + + # Promote scalar to vector + if not isinstance(rhs, TensorSSA): + if isinstance(rhs, Numeric): + vect_val = vector.broadcast(lhs.type, rhs.ir_value(loc=loc, ip=ip)) + else: + elem_attr = get_attr_for_type(lhs.type.element_type, rhs) + vect_attr = ir.DenseElementsAttr.get_splat(lhs.type, elem_attr) + vect_val = arith.constant(lhs.type, vect_attr, loc=loc, ip=ip) + rhs = TensorSSA(vect_val, lhs.shape, lhs.dtype) + + if flip: + lhs, rhs = rhs, lhs + + if op in ( + operator.lt, + operator.le, + operator.gt, + operator.ge, + operator.eq, + operator.ne, + ): + res_type = Boolean + + assert isinstance(rhs, TensorSSA), f"rhs must be TensorSSA but got {rhs}" + + def _broadcast(s, t): + if s == 1: + return t + elif t == 1: + return s + elif s == t: + return s + else: + raise ValueError(f"cannot broadcast {s} and {t}") + + max_rank = max(rank(lhs.shape), rank(rhs.shape)) + lhs_shape = append(lhs.shape, 1, up_to_rank=max_rank) + rhs_shape = append(rhs.shape, 1, up_to_rank=max_rank) + res_shape = transform_leaf(_broadcast, lhs_shape, rhs_shape) + + # broadcast to the same shape + lhs = lhs.broadcast_to(res_shape) + rhs = rhs.broadcast_to(res_shape) + + if ( + op in (operator.add, operator.sub) + and lhs.dtype == Boolean + and rhs.dtype == Boolean + ): + res = op(lhs.to(Int32), rhs.to(Int32)) + zero = zeros_like(res) + res = res.__ne__(zero).to(res_type) + else: + lhs_val = lhs.maybe_downcast() + 
rhs_val = rhs.maybe_downcast() + + if issubclass(lhs.dtype, Integer): + lhs_val = lhs_val.with_signedness(lhs.dtype.signed) + + if issubclass(rhs.dtype, Integer): + rhs_val = rhs_val.with_signedness(rhs.dtype.signed) + + res_vect = op(lhs_val, rhs_val) + res = TensorSSA(res_vect, lhs._shape, res_type) + + return res + + def broadcast_to(self, target_shape: Shape, *, loc=None, ip=None) -> "TensorSSA": + """ + Broadcast the tensor to the target shape. + """ + # pad source shape to the same rank + shape = append(self.shape, 1, up_to_rank=rank(target_shape)) + if shape == target_shape: + return self + + def _check_broadcast(s, t): + if s != t and s != 1: + raise ValueError( + f"src_shape and target_shape must be the same when src_shape is not 1, but got {s} and {t}" + ) + + transform_leaf(_check_broadcast, shape, target_shape) + + # reshape to flatten N-D vector + flat_shp = flatten_to_tuple(shape) + temp_ty = ir.VectorType.get(list(flat_shp), self.dtype.mlir_type) + temp_vect = vector.shape_cast(temp_ty, self, loc=loc, ip=ip) + + # broadcast to result N-D vector + flat_tgt_shp = flatten_to_tuple(target_shape) + temp_tgt_ty = ir.VectorType.get(list(flat_tgt_shp), self.dtype.mlir_type) + temp_tgt_vect = vector.broadcast(temp_tgt_ty, temp_vect, loc=loc, ip=ip) + + res_1d_ty = ir.VectorType.get([size(target_shape)], self.dtype.mlir_type) # type: ignore + res_1d_vect = vector.shape_cast(res_1d_ty, temp_tgt_vect, loc=loc, ip=ip) + + return TensorSSA(res_1d_vect, target_shape, self.dtype) + + def __pow__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the results of tensor^other. + + :param other: The other tensor for exponent. + :type other: TensorSSA + :return: The power of the tensor. + :rtype: TensorSSA + """ + return self._apply_op(operator.pow, other, loc=loc, ip=ip) + + def __rpow__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the results of other^tensor. + + :param other: The other tensor to compute power with. 
+ :type other: TensorSSA + :return: The element-wise power of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.pow, other, flip=True, loc=loc, ip=ip) + + def __add__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the sum of the tensor and another tensor. + + :param other: The other tensor to add. + :type other: TensorSSA + :return: The sum of the two tensors with the same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.add, other, loc=loc, ip=ip) + + def __radd__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the sum of the tensor and another tensor (reverse add) + + :param other: The other tensor to add. + :type other: TensorSSA + :return: The sum of the two tensors with the same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.add, other, flip=True, loc=loc, ip=ip) + + def __sub__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the difference of the tensor and another tensor. + + :param other: The other tensor to subtract. + :type other: TensorSSA + :return: The subtraction of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.sub, other, loc=loc, ip=ip) + + def __rsub__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the difference of the tensor and another tensor (reverse subtract) + + :param other: The other tensor to subtract. + :type other: TensorSSA + :return: The subtraction of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.sub, other, flip=True, loc=loc, ip=ip) + + def __mul__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the multiplication of the tensor and another tensor. + + :param other: The other tensor to multiply. + :type other: TensorSSA + :return: The multiplication of two tensors with same shape as inputs. 
+ :rtype: TensorSSA + """ + return self._apply_op(operator.mul, other, loc=loc, ip=ip) + + def __rmul__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the multiplication of the tensor and another tensor (reverse multiply) + + :param other: The other tensor to multiply. + :type other: TensorSSA + :return: The multiplication of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.mul, other, flip=True, loc=loc, ip=ip) + + def __mod__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the modulo of the tensor and another tensor. + + :param other: The other tensor to compute modulo with. + :type other: TensorSSA + :return: The element-wise modulo of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.mod, other, loc=loc, ip=ip) + + def __rmod__(self, other) -> "TensorSSA": + """ + Returns the modulo of the tensor and another tensor (reverse modulo) + + :param other: The other tensor to compute modulo with. + :type other: TensorSSA + :return: The element-wise modulo of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.mod, other, flip=True) + + def __floordiv__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the floordiv(//) of the tensor and another tensor. + + :param other: The other tensor to compute floordiv with. + :type other: TensorSSA + :return: The floordiv of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.floordiv, other, loc=loc, ip=ip) + + def __rfloordiv__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the floordiv(//) of the tensor and another tensor (reverse floordiv) + + :param other: The other tensor to compute floordiv with. + :type other: TensorSSA + :return: The floordiv of two tensors with same shape as inputs. 
+ :rtype: TensorSSA + """ + return self._apply_op(operator.floordiv, other, flip=True, loc=loc, ip=ip) + + def __truediv__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the truediv(/) of the tensor and another tensor. + + :param other: The other tensor to compute truediv with. + :type other: TensorSSA + :return: The truediv of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.truediv, other, loc=loc, ip=ip) + + def __rtruediv__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the truediv(/) of the tensor and another tensor (reverse truediv) + + :param other: The other tensor to compute truediv with. + :type other: TensorSSA + :return: The truediv of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.truediv, other, flip=True, loc=loc, ip=ip) + + def __eq__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the comparison of the tensor and another tensor as mask + + :param other: The other tensor to compare. + :type other: TensorSSA + :return: The comparison of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.eq, other, loc=loc, ip=ip) + + def __ne__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise not equal comparison of the tensor and another tensor. + + :param other: The other tensor to compare. + :type other: TensorSSA + :return: A boolean tensor with same shape as inputs, True where self != other. + :rtype: TensorSSA + """ + return self._apply_op(operator.ne, other, loc=loc, ip=ip) + + def __lt__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise less than comparison of the tensor and another tensor. + + :param other: The other tensor to compare with. + :type other: TensorSSA + :return: A boolean tensor with same shape as inputs, True where self < other. 
+ :rtype: TensorSSA + """ + return self._apply_op(operator.lt, other, loc=loc, ip=ip) + + def __le__(self, other) -> "TensorSSA": + """ + Returns the element-wise less than or equal comparison of the tensor and another tensor. + + :param other: The other tensor to compare with. + :type other: TensorSSA + :return: A boolean tensor with same shape as inputs, True where self <= other. + :rtype: TensorSSA + """ + return self._apply_op(operator.le, other) + + def __gt__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise greater than comparison of the tensor and another tensor. + + :param other: The other tensor to compare with. + :type other: TensorSSA + :return: A boolean tensor with same shape as inputs, True where self > other. + :rtype: TensorSSA + """ + return self._apply_op(operator.gt, other) + + def __ge__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise greater than or equal comparison of the tensor and another tensor. + + :param other: The other tensor to compare with. + :type other: TensorSSA + :return: A boolean tensor with same shape as inputs, True where self >= other. + :rtype: TensorSSA + """ + return self._apply_op(operator.ge, other, loc=loc, ip=ip) + + def __xor__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise XOR of the tensor and another tensor. + + :param other: The other tensor to perform XOR with. + :type other: TensorSSA + :return: The element-wise XOR of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.xor, other) + + def __rxor__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the bitwise XOR of the tensor and another tensor. + + :param other: The other tensor to compute XOR with. + :type other: TensorSSA + :return: The element-wise bitwise XOR of two tensors with same shape as inputs. 
+ :rtype: TensorSSA + """ + return self._apply_op(operator.xor, other, flip=True, loc=loc, ip=ip) + + def __or__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise OR of the tensor and another tensor. + + :param other: The other tensor to perform OR with. + :type other: TensorSSA + :return: The element-wise OR of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.or_, other) + + def __ror__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise OR of the tensor and another tensor. + + :param other: The other tensor to perform OR with. + :type other: TensorSSA + :return: The element-wise OR of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.or_, other, flip=True) + + def __and__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise AND of the tensor and another tensor. + + :param other: The other tensor to perform AND with. + :type other: TensorSSA + :return: The element-wise AND of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.and_, other) + + def __rand__(self, other, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the element-wise AND of the tensor and another tensor. + + :param other: The other tensor to perform AND with. + :type other: TensorSSA + :return: The element-wise AND of two tensors with same shape as inputs. + :rtype: TensorSSA + """ + return self._apply_op(operator.and_, other, flip=True, loc=loc, ip=ip) + + def __neg__(self, *, loc=None, ip=None) -> "TensorSSA": + """ + Returns the negation of the tensor. + + :return: The element-wise negation of the tensor + :rtype: TensorSSA + """ + + return self._apply_op(operator.sub, 0, flip=True, loc=loc, ip=ip) + + def _flatten_shape_and_coord(self, crd, *, loc=None, ip=None): + # Coalesce and flatten source layout at terminal of coordinate + # (N_0,(N_1,...), ...) 
-> (N_0,N_1,N_2,...) + crd_shp = product_like(self._shape, target_profile=crd, loc=loc, ip=ip) + + # Flatten coordinate + flat_shp = flatten(crd_shp) + assert isinstance(flat_shp, tuple) and is_static(flat_shp) + # (C_0,(C_1,...), ...) -> (C_0,C_1,C_2,...) + flat_crd = flatten(crd) + + assert isinstance(flat_crd, tuple) and is_static(flat_crd) + return flat_shp, flat_crd + + def _build_result(self, res_vect, res_shp, *, loc=None, ip=None): + if isinstance(res_shp, ir.Value): + raise ValueError( + f"expects static shape and coordinates, but got {self._shape} and {crd}" + ) + + # cast back to 1D vector + res_1d_ty = ir.VectorType.get([size(res_shp)], self.type.element_type) + res_1d_vect = vector.shape_cast(res_1d_ty, res_vect, loc=loc, ip=ip) + return TensorSSA(res_1d_vect, res_shp, self.dtype) + + @dsl_user_op + def __getitem__( + self, crd: Coord, *, loc=None, ip=None + ) -> Union["TensorSSA", Numeric]: + """Access or slice tensor elements using coordinates. + + This method implements tensor evaluation T(c) = *(E + L(c)) where E is the iterator/engine + and L is the layout. It supports both direct element access and slicing operations. + + :param crd: Coordinate or slice specification for accessing tensor elements + :type crd: Coord + :param loc: Source location for MLIR operation tracking, defaults to None + :type loc: Optional[Location] + :param ip: Insertion point for MLIR operation, defaults to None + :type ip: Optional[InsertionPoint] + :return: Tensor element value or sliced subtensor + :rtype: Union[TensorSSA, Numeric] + + :raises ValueError: If coordinate access is invalid for the tensor layout + + **Examples:** + + .. 
code-block:: python + + # Create a fragment from rmem as shape (8, 4) + layout = make_layout((8, 4)) + tensor = make_fragment(layout, Float32) + frg = tensor.load() + + # Direct element access + val = frg[0] # Returns first element of fragment + val = frg[(0, 1)] # Returns element at (0, 1) + + # Slice access + sliced = frg[(3, None)] # Returns fragment slice + """ + # short-cut to no-op + if crd is None: + return self + + if not has_underscore(crd): + if self._layout is None: + self._layout = make_layout(self._shape, loc=loc, ip=ip) + idx = crd2idx(crd, self._layout, loc=loc, ip=ip) + idx_val = as_numeric(idx).ir_value(loc=loc, ip=ip) + res_val = vector.extractelement(self, position=idx_val, loc=loc, ip=ip) + return self.dtype(res_val) + + if not is_static(crd): + raise ValueError("dynamic coordinate is not supported") + + flat_shp, flat_crd = self._flatten_shape_and_coord(crd) + + multi_dim_ty = ir.VectorType.get(list(flat_shp), self.type.element_type) + # vector -> vector + tmp_vect = vector.shape_cast(multi_dim_ty, self) + + # Slice and keep dims matching `_` or None + res_shp = slice_(self._shape, crd) + if isinstance(res_shp, ir.Value): + raise TypeError( + f"expects static shape and coordinates, but got {self._shape} and {crd}" + ) + + # Offsets is index of coordinates if NOT `_` otherwise 0 + offsets = [c if c is not None else 0 for c in flat_crd] + # Sizes is size of shapes if `_` otherwise 1 + sizes = [s if c is None else 1 for s, c in zip(flat_shp, flat_crd)] + # Logic stride to index vector. 
Only support stride-1 by vector + strides = [1] * rank(flat_shp) + + # Vector slice on N-D vector + res_ty = ir.VectorType.get(list(sizes), self.type.element_type) + res_vect = vector.extract_strided_slice( + res_ty, tmp_vect, offsets=offsets, sizes=sizes, strides=strides + ) + + # Slice and keep dims matching `_` or None + res_shp = slice_(self._shape, crd) + return self._build_result(res_vect, res_shp, loc=loc, ip=ip) + + @dsl_user_op + def to(self, dtype: Type[Numeric], *, loc=None, ip=None): + """Convert the tensor to a different numeric type. + + :param dtype: The target numeric type to cast to. + :type dtype: Type[Numeric] + :return: A new tensor with the same shape but with elements cast to the target type. + :rtype: TensorSSA + :raises TypeError: If dtype is not a subclass of Numeric. + :raises NotImplementedError: If dtype is an unsigned integer type. + """ + if dtype is ir.Value: + return self + + if not isclass(dtype) or not issubclass(dtype, Numeric): + raise TypeError(f"dtype must be a type of Numeric, but got {type(dtype)}") + + src_dtype = self.dtype + if src_dtype == dtype: + return self + + # maybe downcast can lose signedness + src = self.maybe_downcast().with_signedness(self.signed) + if src_dtype.is_float and dtype.is_float: + res_vect = cutlass_arith.cvtf(src, dtype.mlir_type, loc=loc, ip=ip) + elif src_dtype.is_float and issubclass(dtype, Integer): + res_vect = cutlass_arith.fptoi( + src, dtype.signed, dtype.mlir_type, loc=loc, ip=ip + ) + elif issubclass(src_dtype, Integer) and dtype.is_float: + res_vect = cutlass_arith.itofp( + src, src_dtype.signed, dtype.mlir_type, loc=loc, ip=ip + ) + else: + res_vect = cutlass_arith.int_to_int(src, dtype, loc=loc, ip=ip) + + return TensorSSA(res_vect, self._shape, dtype) + + def ir_value(self, *, loc=None, ip=None): + return self + + def ir_value_int8(self, *, loc=None, ip=None): + """ + Returns int8 ir value of Boolean tensor. + When we need to store Boolean tensor ssa, use ir_value_int8(). 
    def reduce(self, op, init_val, reduction_profile: Coord, *, loc=None, ip=None):
        """
        Perform reduce on selected modes with given predefined reduction op.

        :param op: The reduction operator to use; one of ``ReductionOp.ADD``,
            ``ReductionOp.MUL``, ``ReductionOp.MAX`` or ``ReductionOp.MIN``
        :type op: ReductionOp
        :param init_val: The initial value for the reduction
        :type init_val: numeric
        :param reduction_profile: Specifies which dimensions to reduce. Dimensions marked with `None` are kept.
            Passing ``None`` for the whole profile returns the tensor unchanged.
        :type reduction_profile: Coord
        :param loc: Source location for MLIR operation tracking, defaults to None
        :type loc: Optional[Location]
        :param ip: Insertion point for MLIR operation, defaults to None
        :type ip: Optional[InsertionPoint]

        :return: The reduced tensor; when ``reduction_profile`` has depth 0 the result of
            ``vector.reduction`` is returned directly
        :rtype: TensorSSA
        :raises ValueError: If ``reduction_profile`` is not weakly congruent to the tensor shape
        :raises NotImplementedError: If ``op`` is not one of the supported ``ReductionOp`` members

        **Examples:**

        .. code-block:: python

            reduce(f32 o (4,))
              => f32

            reduce(f32 o (4, 5))
              => f32
            reduce(f32 o (4, (5, 4)), reduction_profile=(None, 1))
              => f32 o (4,)
            reduce(f32 o (4, (5, 4)), reduction_profile=(None, (None, 1)))
              => f32 o (4, (5,))
        """
        # short-cut to no-op
        if reduction_profile is None:
            return self

        if not is_weakly_congruent(reduction_profile, self.shape):
            raise ValueError(
                f"Expect reduction_profile be weakly congruent to the shape of the tensor, "
                f"but got {reduction_profile} and {self.shape}"
            )

        # Map the DSL-level reduction op onto the vector dialect combining kind.
        # NOTE(review): MAX/MIN select the floating-point kinds (MAXIMUMF/MINIMUMF)
        # regardless of element type -- confirm integer tensors are handled elsewhere.
        if op is ReductionOp.ADD:
            red_kind = vector.CombiningKind.ADD
        elif op is ReductionOp.MUL:
            red_kind = vector.CombiningKind.MUL
        elif op is ReductionOp.MAX:
            red_kind = vector.CombiningKind.MAXIMUMF
        elif op is ReductionOp.MIN:
            red_kind = vector.CombiningKind.MINIMUMF
        else:
            raise NotImplementedError(
                f"{op} is not supported, expects one of "
                f"{ReductionOp.ADD, ReductionOp.MUL, ReductionOp.MAX, ReductionOp.MIN}"
            )

        elem_ty = self.element_type
        # Canonicalize to `Numeric` and convert into MLIR value
        init_val = as_numeric(init_val).ir_value(loc=loc, ip=ip)

        # Depth-0 profile: reduce the whole vector in one op, seeded with init_val
        if depth(reduction_profile) == 0:
            return vector.reduction(
                elem_ty.mlir_type, red_kind, self, acc=init_val, loc=loc, ip=ip
            )

        # Flatten shape and profile to matching rank-N tuples
        flat_shp, flat_prof = self._flatten_shape_and_coord(
            reduction_profile, loc=loc, ip=ip
        )
        assert depth(flat_shp) == 1 and depth(flat_prof) == 1
        assert rank(flat_shp) == rank(flat_prof)

        # View the stored vector as an N-D vector of the flattened shape
        temp_ty = ir.VectorType.get(list(flat_shp), elem_ty.mlir_type)
        temp_vect = vector.shape_cast(temp_ty, self, loc=loc, ip=ip)

        # Every profile entry that is not None names a dimension to reduce
        if isinstance(flat_prof, tuple):
            red_dims = [i for i, x in enumerate(flat_prof) if x is not None]
        else:
            red_dims = [0]

        # The accumulator has the shape of the kept (None) dimensions
        temp_acc_shp = slice_(flat_shp, flat_prof, loc=loc, ip=ip)
        temp_acc_ty = ir.VectorType.get(list(temp_acc_shp), elem_ty.mlir_type)

        # Broadcast the scalar initial value to the accumulator type
        init_val = vector.broadcast(temp_acc_ty, init_val, loc=loc, ip=ip)
        res_vect = vector.multi_reduction(
            red_kind, temp_vect, acc=init_val, reduction_dims=red_dims, loc=loc, ip=ip
        )

        # Slice and keep dims matching `_` or None
        res_shp = slice_(self.shape, reduction_profile, loc=loc, ip=ip)
        return self._build_result(res_vect, res_shp, loc=loc, ip=ip)
def ones_like(a, dtype=None, *, loc=None, ip=None):
    """
    Return a TensorSSA of ones with the same shape and type as a given array.

    :param a: The shape and data-type of `a` define these same attributes of the returned array.
    :type a: TensorSSA
    :param dtype: Overrides the data type of the result, defaults to None
    :type dtype: Type[Numeric], optional
    :param loc: Source location for MLIR operation tracking, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point for MLIR operation, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: Tensor of ones with the same shape and type (unless overridden) as `a`.
    :rtype: TensorSSA
    """
    # Accept and forward loc/ip the same way zeros_like does, so callers can
    # attach a location/insertion point to the generated ops.
    return full_like(a, 1, dtype, loc=loc, ip=ip)
def any_(x: TensorSSA, *, loc=None, ip=None) -> Boolean:
    """
    Test whether any tensor element evaluates to True.

    An element counts as True when it compares not-equal to zero.

    :param x: Input tensor.
    :type x: TensorSSA
    :param loc: Source location for MLIR operation tracking, defaults to None
    :type loc: Optional[Location]
    :param ip: Insertion point for MLIR operation, defaults to None
    :type ip: Optional[InsertionPoint]
    :return: True if at least one element of x is non-zero, False otherwise.
    :rtype: Boolean
    """
    # Element-wise "x != 0" mask, then OR-reduce the mask to a single bit.
    zeros = full_like(x, 0, x.dtype, loc=loc, ip=ip)
    nonzero_mask = x != zeros
    reduced = vector.reduction(
        T.bool(), vector.CombiningKind.OR, nonzero_mask, loc=loc, ip=ip
    )
    return Boolean(reduced)
class struct:
    """
    Decorator to abstract C structure in Python DSL.

    **Usage:**

    .. code-block:: python

        # Supports base_dsl scalar int/float elements, array and nested struct:
        @cute.struct
        class complex:
            real : cutlass.Float32
            imag : cutlass.Float32


        @cute.struct
        class StorageA:
            mbarA : cute.struct.MemRange[cutlass.Int64, stage]
            compA : complex
            intA : cutlass.Int16


        # Supports alignment for its elements:
        @cute.struct
        class StorageB:
            a: cute.struct.Align[
                cute.struct.MemRange[cutlass.Float32, size_a], 1024
            ]
            b: cute.struct.Align[
                cute.struct.MemRange[cutlass.Float32, size_b], 1024
            ]
            x: cute.struct.Align[cutlass.Int32, 16]
            compA: cute.struct.Align[complex, 16]


        # Statically get size and alignment:
        size = StorageB.__sizeof__()
        align = StorageB.__alignof__()

        # Allocate and referencing elements:
        storage = allocator.allocate(StorageB)

        storage.a[0] ...
        storage.x ...
        storage.compA.real ...

    :param cls: The struct class with annotations.
    :return: The decorated struct class.
    """

    # inner class for defining a continuous memory region
    class _MemRangeMeta(type):
        """
        A metaclass for creating MemRange classes.

        Subscripting ``struct.MemRange[dtype, size]`` creates a new MemRange
        subclass carrying the given element type and element count.

        :ivar _dtype: The data type of the MemRange.
        :ivar _size: The number of elements in the MemRange.
        """

        _dtype = None
        _size = None

        def __new__(cls, name, bases, dct):
            new_cls = super().__new__(cls, name, bases, dct)
            return new_cls

        def __getitem__(cls, params) -> Type["struct.MemRange"]:
            """
            Build ``struct.MemRange[dtype, size]``.

            :raises TypeError: If not exactly two arguments are given or
                ``dtype`` is not a DSL scalar type.
            """
            # get params from syntax: struct.MemRange[dtype, size]
            # A single (non-tuple) subscript is rejected with the same message
            # instead of failing inside len().
            if not isinstance(params, tuple) or len(params) != 2:
                raise TypeError("Invalid struct.MemRange Arguments")
            dtype, size = params

            if not struct._is_scalar_type(dtype):
                raise TypeError("MemRange only support dsl scalar type!")

            # Create new class with proper name and parameters
            new_cls = type(
                f"struct.MemRange[{dtype.__name__}, {size}]",
                (struct.MemRange,),
                {"_dtype": dtype, "_size": size},
            )
            return new_cls

        @property
        def size(cls):
            """Number of elements in the range."""
            return cls._size

        @property
        def elem_width(cls):
            """Width of one element in bits (``dtype.width``)."""
            return cls._dtype.width

        @property
        def size_in_bytes(cls):
            """Total size in bytes: ``size * elem_width // 8``."""
            return cls.size * cls.elem_width // 8

    class MemRange(metaclass=_MemRangeMeta):
        """
        Defines a range of memory by `MemRange[T, size]`.
        """

        pass

    class _MemRangeData:
        """
        Represents a range of memory bound to a base address.

        :param dtype: The element data type.
        :param size: The number of elements in the range.
        :param base: The base address of the memory range.
        """

        def __init__(self, dtype, size, base):
            """
            Initializes a new memory range.

            :param dtype: The element data type.
            :param size: Number of elements in the range (the ``size`` given to
                ``MemRange[T, size]``). A size of **0** is accepted, but in that
                case the range can only be used for its address (e.g. as a
                partition marker).
            :param base: The base address of the memory range.
            """
            self._dtype = dtype
            self._size = size
            self._base = base

        def data_ptr(self):
            """
            Returns start pointer to the data in this memory range.

            :return: The base pointer recast to the range's element type.
            :raises AssertionError: If the size of the memory range is negative.
            """
            assert self._size >= 0
            return recast_ptr(self._base, dtype=self._dtype)

        def get_tensor(self, layout, swizzle=None, dtype=None):
            """
            Creates a tensor from the memory range.

            :param layout: The layout of the tensor.
            :param swizzle: Optional swizzle pattern.
            :param dtype: Optional data type; defaults to the memory range's data type if not specified.
            :return: A tensor representing the memory range.
            :raises TypeError: If a swizzle is given together with a ComposedLayout.
            :raises AssertionError: If the size of the memory range is not greater than zero.
            """
            assert self._size > 0
            # make tensor
            if isinstance(layout, ComposedLayout) and (swizzle is not None):
                raise TypeError("incompatible layout with swizzle")
            elem_type = self._dtype if dtype is None else dtype
            ptr = recast_ptr(self._base, swizzle, dtype=elem_type)
            return make_tensor(ptr, layout)

        def __getitem__(self, index: int) -> Any:
            """
            Returns a pointer to the element at the specified index.

            :param index: The index of the element to address.
            :return: ``data_ptr() + index``.
            :raises AssertionError: If the index is out of range.
            """
            assert (index >= 0) and (index < self._size)
            return self.data_ptr() + index

    # inner class for aligning a member type
    class _AlignMeta(type):
        """
        A metaclass for creating Align classes via ``struct.Align[T, alignment]``.

        ``T`` must be a struct, MemRange, or a scalar type.

        :ivar _dtype: The data type to be aligned.
        :ivar _align: The alignment of the data type.
        """

        _dtype = None
        _align = None

        def __new__(cls, name, bases, dct):
            return super().__new__(cls, name, bases, dct)

        def __getitem__(cls, params) -> Any:
            """
            Build ``struct.Align[dtype, align]``.

            :raises TypeError: If not exactly two arguments are given or
                ``dtype`` is not a struct, MemRange, or scalar type.
            :raises AssertionError: If ``align`` is not positive.
            """
            # A single (non-tuple) subscript is rejected with the same message
            # instead of failing inside len().
            if not isinstance(params, tuple) or len(params) != 2:
                raise TypeError("Invalid struct.Align Arguments")
            dtype, align = params
            assert align > 0

            if not struct._is_scalar_type(dtype) and not isinstance(
                dtype, (struct, struct._MemRangeMeta)
            ):
                raise TypeError(
                    "align only can be applied to struct/MemRange/base_dsl scalar"
                )

            # Create new class with alignment
            new_cls = type(
                f"struct.Align[{dtype.__name__}, {align}]",
                (struct.Align,),
                {"_dtype": dtype, "_align": align},
            )
            return new_cls

        @property
        def dtype(cls):
            """The wrapped (aligned) type."""
            return cls._dtype

        @property
        def align(cls):
            """The requested alignment."""
            return cls._align

    class Align(metaclass=_AlignMeta):
        """
        Aligns the given type by `Align[T, alignment]`.
        """

        pass

    # util func for base dsl scalar types
    @staticmethod
    def _is_scalar_type(dtype):
        """
        Checks if the given type is a scalar numeric type.

        :param dtype: The type to check.
        :return: True if the type is a subclass of Numeric, False otherwise.
        """
        return isinstance(dtype, type) and issubclass(dtype, Numeric)

    # util func for advancing the running offset
    @staticmethod
    def _advance(offset, nbytes):
        """
        Return ``offset`` advanced by ``nbytes``.

        When ``nbytes`` is a dynamic ``ir.Value`` it is kept as the left operand
        of the addition; otherwise ``offset`` is.
        """
        return nbytes + offset if isinstance(nbytes, ir.Value) else offset + nbytes

    # calculate size and alignment
    def __init__(self, cls):
        """
        Initializes a new struct decorator instance.

        Walks the class annotations in order, assigning each member an aligned
        offset and tracking the strictest alignment seen.

        :param cls: The class representing the structured data type.
        :raises TypeError: If the struct is empty or a member has an
            unsupported type.
        """
        self._cls = cls
        self.__name__ = f"struct::{cls.__name__}"
        # Get the class annotations; a class without any annotation may not
        # expose the attribute at all on some interpreters, so default to {}
        # and let the empty-struct check below report it.
        self._annotations = getattr(cls, "__annotations__", {})
        # Create a dictionary to store the offsets
        self._offsets: Dict[str, int] = {}

        if not self._annotations:
            raise TypeError("Empty struct is not supported!")

        # Calculate the offsets and alignment
        offset = 0
        alignment = 1
        for name, member in self._annotations.items():
            # get alignment of member; Align[T, n] is unwrapped to T
            sub_align = 1
            if isinstance(member, struct._AlignMeta):
                sub_align = member.align
                member = member.dtype

            # size of scalar
            if struct._is_scalar_type(member):
                # sub-byte scalars still occupy one byte
                dtype_size = max(1, member.width // 8)
                sub_align = max(dtype_size, sub_align)
                offset = self.align_offset(offset, sub_align)
                self._offsets[name] = offset
                offset = struct._advance(offset, dtype_size)
            # size of array is size_in_bytes, alignment is elem_size
            elif isinstance(member, struct._MemRangeMeta):
                # elem_width // 8 is 0 for sub-byte element types; taking the
                # max with sub_align (>= 1) keeps the alignment valid, since
                # align_offset() rejects an alignment of 0. Empty ranges are
                # allowed as marker-only members.
                sub_align = max(member.elem_width // 8, sub_align)
                offset = self.align_offset(offset, sub_align)
                self._offsets[name] = offset
                offset = struct._advance(offset, member.size_in_bytes)
            # size of struct
            elif isinstance(member, struct):
                sub_align = max(member.__alignof__(), sub_align)
                offset = self.align_offset(offset, sub_align)
                self._offsets[name] = offset
                offset = struct._advance(offset, member.__sizeof__())
            else:
                raise TypeError(
                    f"Struct element only support struct/array/base_dsl scalar, "
                    f"but got {member}"
                )
            # Total alignment determined by the strictest requirement
            alignment = max(alignment, sub_align)
        # Total size determined by alignment
        self._align_of = alignment
        self._size_of = self.align_offset(offset, alignment)

    # create the __init__ method for decorated struct
    def __call__(self, base: Any) -> Any:
        """
        Creates a new instance of the decorated struct.

        :param base: The base address of the struct.
        :return: An instance of the decorated class whose members are bound to
            ``base`` plus their offsets.
        :raises TypeError: If the base pointer is not byte-sized or a member
            has an unsupported type.
        """
        if base.type.value_type.width != 8:
            raise TypeError("struct base ptr value type must be byte sized.")
        # make a new object of the user-defined decorated struct, otherwise
        # every instance would share (and override) the same self._cls
        instance = self._cls()
        setattr(instance, "_base", base)
        for name, off in self._offsets.items():
            obj = self._annotations[name]
            if isinstance(obj, struct._AlignMeta):
                obj = obj.dtype
            if struct._is_scalar_type(obj):
                new_obj = recast_ptr(base + off, dtype=obj)
            elif isinstance(obj, struct._MemRangeMeta):
                new_obj = struct._MemRangeData(obj._dtype, obj._size, base + off)
            elif isinstance(obj, struct):
                new_obj = obj(base + off)
            else:
                raise TypeError(
                    f"Struct element only support struct/array/base_dsl scalar, "
                    f"but got {obj}"
                )
            setattr(instance, name, new_obj)
        return instance

    # get size
    def size_in_bytes(self) -> int:
        """
        Returns the size of the struct in bytes.

        :return: The size of the struct.
        """
        return self._size_of

    # get size
    def __sizeof__(self) -> int:
        """Return the struct size in bytes (same value as ``size_in_bytes()``)."""
        return self._size_of

    # get alignment
    def __alignof__(self) -> int:
        """Return the struct alignment: the strictest alignment of its members."""
        return self._align_of

    # util func for aligning offset
    @staticmethod
    def align_offset(offset, align):
        """
        Return the round-up offset up to the next multiple of align.

        :param offset: Offset to round up.
        :param align: Alignment; must be a strictly positive power of 2.
        :raises AssertionError: If ``align`` is not a positive power of 2.
        """
        assert align > 0 and not (
            align & (align - 1)
        ), "align should be a strictly positive power of 2."
        return (offset + (align - 1)) & ~(align - 1)
def acos(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Return the arc cosine of ``a``, applied element-wise for tensors.

    :param a: Input tensor or scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Enable fast math optimizations, defaults to False
    :type fastmath: bool, optional
    :return: Arc cosine of each element of the input
    :rtype: Union[TensorSSA, Numeric]

    Example:

    .. code-block::

        x = cute.make_fragment(layout)  # Create tensor
        y = x.load()  # Load values
        z = acos(y)  # Compute arc cosine
    """
    # Delegate type checking and tensor/scalar dispatch to the shared helper.
    op = math.acos
    return _math_op(op, fastmath, a)
def atan(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Element-wise arc tangent (currently not implemented).

    The function is kept so the name exists alongside the other math ops, but
    every call raises.

    :param a: Input tensor
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Enable fast math optimizations, defaults to False
    :type fastmath: bool, optional
    :raises NotImplementedError: Always.
    """
    # The unreachable `return _math_op(math.atan, ...)` that used to follow the
    # raise was dead code and has been dropped.
    raise NotImplementedError("atan is not implemented")
def exp(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Return the exponential of ``a``, applied element-wise for tensors.

    :param a: Input tensor or scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Enable fast math optimizations, defaults to False
    :type fastmath: bool, optional
    :return: Exponential of each element of the input
    :rtype: Union[TensorSSA, Numeric]

    Example:

    .. code-block::

        x = cute.make_fragment(layout)  # Create tensor
        y = x.load()  # Load values
        z = exp(y)  # Compute exponential
    """
    # Delegate type checking and tensor/scalar dispatch to the shared helper.
    op = math.exp
    return _math_op(op, fastmath, a)
def log(
    a: Union[TensorSSA, Numeric], fastmath: bool = False
) -> Union[TensorSSA, Numeric]:
    """Return the natural logarithm of ``a``, applied element-wise for tensors.

    :param a: Input tensor or scalar
    :type a: Union[TensorSSA, Numeric]
    :param fastmath: Enable fast math optimizations, defaults to False
    :type fastmath: bool, optional
    :return: Natural logarithm of each element of the input
    :rtype: Union[TensorSSA, Numeric]

    Example:

    .. code-block::

        x = cute.make_fragment(layout)  # Create tensor
        y = x.load()  # Load values
        z = log(y)  # Compute natural logarithm
    """
    # Delegate type checking and tensor/scalar dispatch to the shared helper.
    op = math.log
    return _math_op(op, fastmath, a)
+ + :param a: Input tensor + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the base-10 logarithm of each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = log10(y) # Compute log base 10 + """ + return _math_op(math.log10, fastmath, a) + + +def rsqrt( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise reciprocal square root of the input tensor. + + Computes 1/√x element-wise. + + :param a: Input tensor + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the reciprocal square root of each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = rsqrt(y) # Compute 1/√x + """ + return _math_op(math.rsqrt, fastmath, a) + + +def sin( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise sine of the input tensor. + + :param a: Input tensor (in radians) + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the sine of each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = sin(y) # Compute sine + """ + return _math_op(math.sin, fastmath, a) + + +def sqrt( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise square root of the input tensor. 
+ + :param a: Input tensor + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the square root of each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = sqrt(y) # Compute square root + """ + return _math_op(math.sqrt, fastmath, a) + + +def tan( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise tangent of the input tensor. + + :param a: Input tensor (in radians) + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the tangent of each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = tan(y) # Compute tangent + """ + return _math_op(math.tan, fastmath, a) + + +def tanh( + a: Union[TensorSSA, Numeric], fastmath: bool = False +) -> Union[TensorSSA, Numeric]: + """Compute element-wise hyperbolic tangent of the input tensor. + + :param a: Input tensor + :type a: Union[TensorSSA, Numeric] + :param fastmath: Enable fast math optimizations, defaults to False + :type fastmath: bool, optional + :return: Tensor containing the hyperbolic tangent of each element + :rtype: Union[TensorSSA, Numeric] + + Example: + + .. 
code-block:: + + x = cute.make_fragment(layout) # Create tensor + y = x.load() # Load values + z = tanh(y) # Compute hyperbolic tangent + """ + return _math_op(math.tanh, fastmath, a) + + +__all__ = [ + "acos", + "asin", + "atan", + "atan2", + "cos", + "erf", + "exp", + "exp2", + "log", + "log10", + "log2", + "rsqrt", + "sin", + "sqrt", + "tan", + "tanh", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0655bb09c05ae84714656020127cb41a4f28fbf6 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/__init__.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from . import warp +from . import cpasync +from . import warpgroup +from . 
import tcgen05 + +from .common import * +from .helpers import * + + +# __all__ is required here for documentation generation +__all__ = [ + "OpError", + "MmaUniversalOp", + "CopyUniversalOp", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/common.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/common.py new file mode 100644 index 0000000000000000000000000000000000000000..1b0c4c82debcd55cd7f3d7df0e21920cda83ca18 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/common.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. +import enum +from dataclasses import dataclass +from typing import Type, Optional + +from cutlass.cutlass_dsl import DSLBaseError + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir import ir + +from .. import core +from ..typing import Float16, Float32, Float64, Numeric + + +class OpError(DSLBaseError): + """ + An exception class for Op construction errors. 
+ """ + + def __init__( + self, op: core.Op, message: str, suggestion: Optional[str] = None + ) -> None: + if suggestion is None: + # Default suggestion + suggestion = "Check your Op construction code" + super().__init__( + message, + error_code=f"{op.__class__.__name__} error", + suggestion=suggestion, + ) + + +#################################################################################################### +# +# MMA Ops and Traits +# +#################################################################################################### + + +@dataclass(frozen=True) +class MmaUniversalOp(core.MmaOp): + """ + The universal MMA Operation. + + This Operation currently expects the A/B operands as well as the accumulator to share the same + data types. + + :param abacc_dtype: The data type for the A/B operands and the accumulator + :type abacc_dtype: Type[Numeric] + """ + + abacc_dtype: Type[Numeric] + + def __post_init__(self) -> None: + if self.abacc_dtype not in [Float16, Float32, Float64]: + raise OpError( + self, + f"expects the 'abacc_dtype' Op parameter to be one of Float16, Float32, or Float64", + ) + + def __str__(self) -> str: + return ( + "universal MMA Operation using FMA" + f"\n A/B/Accumulator data type = {self.abacc_dtype}" + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaUniversalTrait": + shape_mnk_attr = ir.Attribute.parse(f'#cute.shape<"(1,1,1)">') + atom_ty = _cute_nvgpu_ir.UniversalFmaAtomType.get( + shape_mnk_attr, + self.abacc_dtype.mlir_type, + self.abacc_dtype.mlir_type, + self.abacc_dtype.mlir_type, + ) + return MmaUniversalTrait(_cute_ir.atom(atom_ty, loc=loc, ip=ip)) + + def _verify_fragment_A(self, input, *, loc=None, ip=None): + pass + + def _verify_fragment_B(self, input, *, loc=None, ip=None): + pass + +class MmaUniversalTrait(core.Trait): + pass + + +#################################################################################################### +# +# Copy Ops and Traits +# 
+#################################################################################################### + + +class MemoryOrder(enum.Enum): + WEAK = _cute_ir.MemOrderKind.WEAK + RELAXED = _cute_ir.MemOrderKind.RELAXED + ACQUIRE = _cute_ir.MemOrderKind.ACQUIRE + RELEASE = _cute_ir.MemOrderKind.RELEASE + ACQ_REL = _cute_ir.MemOrderKind.ACQ_REL + SC = _cute_ir.MemOrderKind.SC + MMIO = _cute_ir.MemOrderKind.MMIO + CONSTANT = _cute_ir.MemOrderKind.CONSTANT + VOLATILE = _cute_ir.MemOrderKind.VOLATILE + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir(self) -> _cute_ir.MemOrderKind: + return self.value + + +class MemoryScope(enum.Enum): + CTA = _cute_ir.MemScopeKind.CTA + CLUSTER = _cute_ir.MemScopeKind.CLUSTER + GPU = _cute_ir.MemScopeKind.GPU + SYS = _cute_ir.MemScopeKind.SYS + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir(self) -> _cute_ir.MemScopeKind: + return self.value + +@dataclass(frozen=True) +class CopyUniversalOp(core.CopyOp): + """ + The universal Copy Operation. + + When creating a Copy Atom out of this operation, the expected usage pattern is + + .. code-block:: python + + op = cute.nvgpu.CopyUniversalOp() + atom = cute.make_copy_atom(op, tensor_dtype, num_bits_per_copy=64) + + - ``tensor_dtype`` is the data type used to build the reference TV Layout (either the source \ + or the destination TV Layout) in unit of tensor elements and is used for partitioning by \ + ``TiledCopy`` for example + - ``num_bits_per_copy`` is a kw argument specifying the number of bits to copy per Atom \ + execution. This can be larger than the width of the above data type. When not provided, \ + the compiler will do a best effort at auto-vectorizing. 
+ """ + + def __str__(self) -> str: + return "universal Copy Operation" + + def _make_trait( + self, + copy_internal_type: Type[Numeric], + *, + loc=None, + ip=None, + **kwargs, + ) -> "CopyUniversalTrait": + num_bits_per_copy = kwargs.get("num_bits_per_copy", 0) + memory_order = kwargs.get("memory_order", MemoryOrder.WEAK) + memory_scope = kwargs.get("memory_scope", MemoryScope.CTA) + if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy < 0): + raise ValueError( + "expects a 'num_bits_per_copy' kw argument of type int that is non-negative " + f"when creating a copy Atom for {self.__class__.__name__}" + ) + ty = _cute_nvgpu_ir.CopyAtomSIMTSyncCopyType.get( + copy_internal_type.mlir_type, + num_bits_per_copy, + memory_order._to_ir(), + memory_scope._to_ir(), + ) + return CopyUniversalTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class CopyUniversalTrait(core.Trait): + pass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..246360c2eb43ed5c4ca45127c579bc9f496caa08 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ +from .copy import * +from .helpers import * + + +# __all__ is required here for documentation generation +__all__ = [ + # + # copy.py + # + "LoadCacheMode", + "CopyG2SOp", + "CopyBulkTensorTileG2SOp", + "CopyBulkTensorTileG2SMulticastOp", + "CopyBulkTensorTileS2GOp", + "CopyReduceBulkTensorTileS2GOp", + # + # helpers.py + # + "make_tiled_tma_atom", + "tma_partition", + "create_tma_multicast_mask", + "prefetch_descriptor", + "copy_tensormap", + "update_tma_descriptor", + "fence_tma_desc_acquire", + "cp_fence_tma_desc_release", + "fence_tma_desc_release", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py new file mode 100644 index 0000000000000000000000000000000000000000..a15495602304700d19803825d93004e0fa9fc509 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/copy.py @@ -0,0 +1,471 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ +import enum +from dataclasses import dataclass +from typing import Optional, Type + +from cutlass.cutlass_dsl import CuTeDSL, t + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir import ir + +from ...core import CopyOp, Trait, ReductionOp +from ...typing import Int16, Pointer, Integer, Numeric +from ..common import OpError +from ..tcgen05.mma import CtaGroup + + +#################################################################################################### +# +# Asynchronous copies +# +#################################################################################################### + + +class LoadCacheMode(enum.Enum): + """ + An enumeration for the possible cache modes of a non-bulk ``cp.async`` instruction. + + See the `PTX documentation `__. + """ + + ALWAYS = _cute_nvgpu_ir.LoadCacheMode.always + GLOBAL = _cute_nvgpu_ir.LoadCacheMode.global_ + STREAMING = _cute_nvgpu_ir.LoadCacheMode.streaming + LAST_USE = _cute_nvgpu_ir.LoadCacheMode.last_use + NONE = _cute_nvgpu_ir.LoadCacheMode.none + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir(self) -> _cute_nvgpu_ir.LoadCacheMode: + return self.value + + +@dataclass(frozen=True) +class CopyG2SOp(CopyOp): + """ + Non-bulk asynchronous GMEM to SMEM Copy Operation. + + See the `PTX documentation `__.
+ """ + + cache_mode: LoadCacheMode = LoadCacheMode.ALWAYS + + def __str__(self) -> str: + res = "cp.async GMEM -> SMEM copy Operation" + if self.cache_mode != LoadCacheMode.ALWAYS: + res += f"\n with cache mode = {self.cache_mode}" + return res + + def _make_trait( + self, + copy_internal_type: Type[t.Numeric], + *, + loc=None, + ip=None, + **kwargs, + ) -> "CopyG2STrait": + num_bits_per_copy = kwargs.get("num_bits_per_copy", None) + # Verify that the user provided enum values + if not isinstance(self.cache_mode, LoadCacheMode): + raise OpError( + self, + "expects the 'cache_mode' Op parameter to be a LoadCacheMode instance", + ) + # 'num_bits_per_copy' is mandatory for this Op and must be a positive int + if not isinstance(num_bits_per_copy, int) or (num_bits_per_copy <= 0): + raise ValueError( + "expects a 'num_bits_per_copy' kw argument of type int that is positive " + f"when creating a copy Atom for {self.__class__.__name__}" + ) + ty = _cute_nvgpu_ir.CopyAtomSIMTAsyncCopyType.get( + copy_internal_type.mlir_type, self.cache_mode._to_ir(), num_bits_per_copy + ) + return CopyG2STrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class CopyG2STrait(Trait): + pass + + +#################################################################################################### +# +# Bulk tensor copies a.k.a TMA copies +# +#################################################################################################### + +TMA_MBAR_PTR_FIELD_NAME = "tma_bar" +TMA_MASK_FIELD_NAME = "mcast_mask" +TMA_DESC_PTR_FIELD_NAME = "tma_descriptor_ptr" + +# +# TMA GMEM -> SMEM copies +# + + +@dataclass(frozen=True) +class CopyBulkTensorTileG2SOp(CopyOp): + """ + Bulk tensor asynchronous GMEM to SMEM Copy Operation using the TMA unit. + + See the `PTX documentation `__. + This Operation uses TMA in the ``.tile`` mode.
+ """ + + cta_group: CtaGroup = CtaGroup.ONE + + admissible_archs = [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ] + + def __post_init__(self) -> None: + if not isinstance(self.cta_group, CtaGroup): + raise OpError( + self, "expects the 'cta_group' parameter to be a CtaGroup instance" + ) + # Arch verification + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + # A CTA group of 2 is rejected for every arch whose name starts with "sm_90" + if (self.cta_group == CtaGroup.TWO) and arch[:5] == "sm_90": + raise OpError( + self, + f"CTA group of 2 is tcgen05-specific and is not compatible with {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + + def __str__(self) -> str: + res = "cp.async GMEM -> SMEM bulk tensor copy Operation" + if self.cta_group == CtaGroup.TWO: + res += f"\n CTA group = 2" + return res + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "CopyBulkTensorTileG2SNonExecTrait": + raise NotImplementedError( + "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" + ) + + def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum: + if self.cta_group == CtaGroup.ONE: + return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_90 + elif self.cta_group == CtaGroup.TWO: + return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_100_2sm + else: + assert False, "unrecognized self.cta_group" + + +class CopyBulkTensorTileG2SNonExecTrait(Trait): + # We allow kw args to be dropped so that the user can write common code for non-multicast + # and multicast loads. + def unpack( + self, + *, + loc=None, + ip=None, + tma_bar_ptr: Optional[Pointer] = None, + tma_desc_ptr: Optional[Pointer] = None, + **kwargs, + ): + """ + Custom implementation of unpack for non-executable TMAs.
+ + The non-multicast TMA load requires a `tma_bar_ptr` keyword argument to be provided when + using `cute.copy`. Any other kw arguments will be ignored instead of triggering an error. + """ + if not isinstance(tma_bar_ptr, Pointer): + raise ValueError( + "expects a pointer to an mbarrier to be provided via the tma_bar_ptr kw argument" + ) + exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip) + attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_MBAR_PTR_FIELD_NAME}>" + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_bar_ptr.value, loc=loc, ip=ip + ) + if isinstance(tma_desc_ptr, Pointer): + attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_DESC_PTR_FIELD_NAME}>" + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip + ) + return exec_value + + +# +# TMA GMEM -> SMEM multicast copies +# + + +@dataclass(frozen=True) +class CopyBulkTensorTileG2SMulticastOp(CopyOp): + """ + Bulk tensor asynchronous multicast GMEM to SMEM Copy Operation using the TMA unit. + + See the `PTX documentation `__. + This Operation uses TMA in the ``.tile`` mode.
+ """ + + cta_group: CtaGroup = CtaGroup.ONE + + admissible_archs = [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ] + + def __post_init__(self): + if not isinstance(self.cta_group, CtaGroup): + raise OpError( + self, "expects the 'cta_group' parameter to be a CtaGroup instance" + ) + # Arch verification + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + # A CTA group of 2 is rejected for every arch whose name starts with "sm_90" + if (self.cta_group == CtaGroup.TWO) and arch[:5] == "sm_90": + raise OpError( + self, + f"CTA group of 2 is tcgen05-specific and is not compatible with {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + + def __str__(self) -> str: + res = "cp.async GMEM -> SMEM bulk tensor multicast copy Operation" + if self.cta_group == CtaGroup.TWO: + res += f"\n CTA group = 2" + return res + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "CopyBulkTensorTileG2SMulticastNonExecTrait": + raise NotImplementedError( + "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" + ) + + def _to_ir(self) -> _cute_nvgpu_ir.TiledTmaLoadEnum: + if self.cta_group == CtaGroup.ONE: + return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_90_multicast + elif self.cta_group == CtaGroup.TWO: + return _cute_nvgpu_ir.TiledTmaLoadEnum.sm_100_2sm_multicast + else: + assert False, "unrecognized self.cta_group" + + +class CopyBulkTensorTileG2SMulticastNonExecTrait(Trait): + def unpack( + self, + *, + loc=None, + ip=None, + tma_bar_ptr: Optional[Pointer] = None, + mcast_mask=None, + tma_desc_ptr=None, + ): + """ + Custom implementation of unpack for non-executable TMAs. + + The multicast TMA load requires a `tma_bar_ptr` and a `mcast_mask` keyword arguments to be + provided when using `cute.copy`.
+ """ + if not isinstance(tma_bar_ptr, Pointer): + raise ValueError( + "expects a pointer to an mbarrier to be provided via the tma_bar_ptr kw argument" + ) + if not isinstance(mcast_mask, Integer): + raise ValueError( + "expects a multicast mask to be provided via the mcast_mask kw argument" + ) + exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip) + # Attach the mbarrier pointer to its named field of the TMA load atom + attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_MBAR_PTR_FIELD_NAME}>" + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_bar_ptr.value, loc=loc, ip=ip + ) + # Attach the multicast mask (converted to Int16) to its named field + attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_MASK_FIELD_NAME}>" + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, Int16(mcast_mask).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + # The TMA descriptor pointer is optional and only set when a Pointer is given + if isinstance(tma_desc_ptr, Pointer): + attr_str = f"#cute_nvgpu.atom_copy_field_tmaload<{TMA_DESC_PTR_FIELD_NAME}>" + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip + ) + return exec_value + + +# +# TMA SMEM -> GMEM copies +# + + +@dataclass(frozen=True) +class CopyBulkTensorTileS2GOp(CopyOp): + """ + Bulk tensor asynchronous SMEM to GMEM Copy Operation using the TMA unit. + + See the `PTX documentation `__. + This Operation uses TMA in the ``.tile`` mode.
+ """ + + admissible_archs = [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ] + + def __post_init__(self): + # Arch verification + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + + def __str__(self) -> str: + return "cp.async SMEM -> GMEM bulk tensor copy Operation" + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "CopyBulkTensorTileS2GTrait": + raise NotImplementedError( + "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" + ) + + +class CopyBulkTensorTileS2GTrait(Trait): + def unpack(self, *, loc=None, ip=None, tma_desc_ptr: Optional[Pointer] = None): + """ + Custom implementation of unpack for non-executable TMAs. + """ + exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip) + if isinstance(tma_desc_ptr, Pointer): + attr_str = ( + f"#cute_nvgpu.atom_copy_field_tmastore<{TMA_DESC_PTR_FIELD_NAME}>" + ) + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip + ) + return exec_value + +@dataclass(frozen=True) +class CopyReduceBulkTensorTileS2GOp(CopyOp): + """ + Bulk tensor asynchronous SMEM to GMEM Reduction Operation using the TMA unit. + + See the `PTX documentation `__. + This Operation uses TMA in the ``.tile`` mode. 
+ """ + + reduction_kind: ReductionOp = ReductionOp.ADD + + admissible_archs = [ + "sm_90", + "sm_90a", + "sm_100a", + "sm_100f", + ] + + def __post_init__(self): + # Arch verification (the hook must be spelled __post_init__ for the dataclass to call it) + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + + def __str__(self) -> str: + return "cp.async SMEM -> GMEM bulk tensor reduction Operation" + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "CopyReduceBulkTensorTileS2GTrait": + raise NotImplementedError( + "Use cpasync.make_tiled_tma_atom to obtain a copy Atom for TMA" + ) + + def _to_ir(self) -> _cute_nvgpu_ir.ReductionKind: + if self.reduction_kind == ReductionOp.ADD: + return _cute_nvgpu_ir.ReductionKind.ADD + elif self.reduction_kind == ReductionOp.MIN: + return _cute_nvgpu_ir.ReductionKind.MIN + elif self.reduction_kind == ReductionOp.MAX: + return _cute_nvgpu_ir.ReductionKind.MAX + elif self.reduction_kind == ReductionOp.INC: + return _cute_nvgpu_ir.ReductionKind.INC + elif self.reduction_kind == ReductionOp.DEC: + return _cute_nvgpu_ir.ReductionKind.DEC + elif self.reduction_kind == ReductionOp.AND: + return _cute_nvgpu_ir.ReductionKind.AND + elif self.reduction_kind == ReductionOp.OR: + return _cute_nvgpu_ir.ReductionKind.OR + elif self.reduction_kind == ReductionOp.XOR: + return _cute_nvgpu_ir.ReductionKind.XOR + else: + assert False, "unrecognized self.reduction_kind" + + +class CopyReduceBulkTensorTileS2GTrait(Trait): + def unpack(self, *, loc=None, ip=None, tma_desc_ptr: Optional[Pointer] = None): + """ + Custom implementation of unpack for non-executable TMAs.
+ """ + exec_value = _cute_nvgpu_ir.atom_make_exec_tma(self.value, loc=loc, ip=ip) + if isinstance(tma_desc_ptr, Pointer): + attr_str = ( + f"#cute_nvgpu.atom_copy_field_tmareduce<{TMA_DESC_PTR_FIELD_NAME}>" + ) + attr = ir.Attribute.parse(attr_str) + exec_value = _cute_nvgpu_ir.atom_set_value( + exec_value, attr, tma_desc_ptr.value, loc=loc, ip=ip + ) + return exec_value + +__all__ = [ + "LoadCacheMode", + "CopyG2SOp", + "CopyBulkTensorTileG2SOp", + "CopyBulkTensorTileG2SMulticastOp", + "CopyBulkTensorTileS2GOp", + "CopyReduceBulkTensorTileS2GOp", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..f64f07f167501d1805096373e915017612de4387 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py @@ -0,0 +1,341 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Optional, Tuple, Type, Union + +from cutlass.cutlass_dsl import dsl_user_op + +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir.dialects import llvm + +from ...typing import Coord, Layout, Tensor, Tiler, Pointer, Int16, Numeric, NumericMeta +from ... 
@dsl_user_op
def make_tiled_tma_atom(
    op: Union[
        CopyBulkTensorTileG2SOp,
        CopyBulkTensorTileG2SMulticastOp,
        CopyBulkTensorTileS2GOp,
        CopyReduceBulkTensorTileS2GOp,
    ],
    gmem_tensor: Tensor,
    smem_layout: Union[Layout, core.ComposedLayout],
    cta_tiler: Tiler,
    num_multicast: int = 1,
    *,
    internal_type: Optional[Type[Numeric]] = None,
    loc=None,
    ip=None,
) -> Tuple[core.CopyAtom, Tensor]:
    """
    Build a ``.tile``-mode TMA Copy Atom for moving tiles of a GMEM tensor to/from an SMEM buffer.

    From a GMEM tensor, a SMEM layout and a CTA-level Tiler, this selects the bulk tensor
    asynchronous copy instruction with the maximum "TMA vector length" that is consistent with
    the provided layout and Tiler. The Tiler is composed with the identity layout of the GMEM
    tensor, and the result is handed to the builder matching the kind of ``op`` (load, multicast
    load, store or reduce).

    Two results are returned:

    1. the Copy Atom
    2. the so-called TMA tensor, which maps logical coordinates of the GMEM tensor to coordinates
       that the TMA unit can consume. TMA tensors have basis stride elements so that the
       associated layout can output coordinates; otherwise they can be partitioned like any
       other CuTe tensor using the algebra.

    :param op: The Copy Operation to construct an Atom for
    :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileS2GOp, CopyReduceBulkTensorTileS2GOp]
    :param gmem_tensor: The GMEM tensor involved in the Copy
    :type gmem_tensor: Tensor
    :param smem_layout: The SMEM layout to construct the Copy Atom for
    :type smem_layout: Union[Layout, core.ComposedLayout]
    :param cta_tiler: The CTA Tiler to use
    :type cta_tiler: Tiler
    :param num_multicast: The multicast factor; only consulted for the two G2S load Operations
    :type num_multicast: int
    :param internal_type: Optional internal data type to use when the actual data type is not
        supported by the TMA unit
    :type internal_type: Type[Numeric]
    :return: A Copy Atom for this Operation and the associated TMA tensor
    :rtype: Tuple[core.CopyAtom, Tensor]
    :raises TypeError: if ``internal_type`` is provided but is not a Numeric type
    :raises ValueError: if ``num_multicast`` is invalid for the given load Operation, or if
        ``op`` is not one of the supported bulk tensor Copy Operations
    """

    # Lower the optional internal type to its MLIR counterpart up front
    if internal_type is not None:
        if not isinstance(internal_type, NumericMeta):
            raise TypeError(f"internal_type must be a Numeric, but got {internal_type}")
        internal_type = internal_type.mlir_type

    # Restrict the identity (coordinate) layout of the GMEM tensor to the CTA tile
    identity_layout = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip)
    cta_v_map = core.composition(identity_layout, cta_tiler, loc=loc, ip=ip)

    # Every builder below yields a pair: [0] the non-executable atom, [1] the TMA tensor

    if isinstance(op, CopyBulkTensorTileG2SOp):
        if num_multicast != 1:
            raise ValueError(
                f"expects num_multicast to be 1 for non multicast G2S copies, "
                f"but got {num_multicast}"
            )
        load = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load(
            gmem_tensor.value,
            smem_layout,
            cta_v_map,
            op._to_ir(),
            num_multicast=num_multicast,
            internal_type=internal_type,
            loc=loc,
            ip=ip,
        )
        return core.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(load[0])), load[1]

    if isinstance(op, CopyBulkTensorTileG2SMulticastOp):
        if num_multicast < 1:
            raise ValueError(
                f"expects num_multicast to be >= 1 for multicast G2S copies, "
                f"but got {num_multicast}"
            )
        mcast_load = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load(
            gmem_tensor.value,
            smem_layout,
            cta_v_map,
            op._to_ir(),
            num_multicast=num_multicast,
            internal_type=internal_type,
            loc=loc,
            ip=ip,
        )
        return (
            core.CopyAtom(
                op, CopyBulkTensorTileG2SMulticastNonExecTrait(mcast_load[0])
            ),
            mcast_load[1],
        )

    if isinstance(op, CopyBulkTensorTileS2GOp):
        store = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_store(
            gmem_tensor.value,
            smem_layout,
            cta_v_map,
            internal_type=internal_type,
            loc=loc,
            ip=ip,
        )
        return core.CopyAtom(op, CopyBulkTensorTileS2GTrait(store[0])), store[1]

    if isinstance(op, CopyReduceBulkTensorTileS2GOp):
        reduce = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_reduce(
            gmem_tensor.value,
            smem_layout,
            cta_v_map,
            op._to_ir(),
            internal_type=internal_type,
            loc=loc,
            ip=ip,
        )
        return core.CopyAtom(op, CopyReduceBulkTensorTileS2GTrait(reduce[0])), reduce[1]

    raise ValueError(f"expects a bulk tensor (TMA) Copy Op, but got {op}")
@dsl_user_op
def update_tma_descriptor(
    tma_atom: core.CopyAtom,
    gmem_tensor: Tensor,
    tma_desc_ptr: Pointer,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Updates the TMA descriptor in the memory location pointed to by the provided pointer using
    information from a TMA Copy Atom and the provided GMEM tensor.

    Specifically, the following fields of the TMA descriptor will be updated:

    1. the GMEM tensor base address
    2. the GMEM tensor shape
    3. the GMEM tensor stride

    Other fields of the TMA descriptor are left unchanged.

    :param tma_atom: The TMA Copy Atom
    :type tma_atom: CopyAtom
    :param gmem_tensor: The GMEM tensor
    :type gmem_tensor: Tensor
    :param tma_desc_ptr: The pointer to the memory location of the descriptor to update
    :type tma_desc_ptr: Pointer
    """
    _cute_nvgpu_ir.update_tma_desc(
        tma_atom._trait.value, gmem_tensor.value, tma_desc_ptr.value, loc=loc, ip=ip
    )
+ """ + llvm.inline_asm( + None, + [], + "fence.proxy.tensormap::generic.release.gpu;", + "", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..9b4aa0dbb207dfad2832ddf7a80504c7cf591ff1 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/helpers.py @@ -0,0 +1,249 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Optional, Tuple, Type, Union + +from cutlass.cutlass_dsl import dsl_user_op + +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir + +from .. 
@dsl_user_op
def make_tiled_tma_atom_A(
    op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
    gmem_tensor: Tensor,
    smem_layout: Union[Layout, core.ComposedLayout],
    mma_tiler_mnk: Shape,
    tiled_mma: core.TiledMma,
    cluster_shape_vmnk: Shape,
    *,
    internal_type: Optional[Type[Numeric]] = None,
    loc=None,
    ip=None,
) -> Tuple[core.CopyAtom, Tensor]:
    """
    Build a ``.tile``-mode TMA Copy Atom (``cp.async.bulk.tensor``) for loading the A operand,
    accounting for the MK projection of the TiledMMA.

    From a GMEM tensor, a SMEM layout, a MMA Tiler, a TiledMma and a Cluster-level shape, this
    selects the bulk tensor asynchronous copy instruction with the maximum "TMA vector length"
    that is consistent with the provided layout, Tiler and ``tiled_mma`` (M-mode and K-mode).
    For the multicast Operation, the multicast factor is the size of mode 2 (the N-mode) of
    ``cluster_shape_vmnk``; for the plain Operation it is 1.

    Two results are returned:

    1. the Copy Atom
    2. the so-called TMA tensor, which maps logical coordinates of the GMEM tensor to coordinates
       that the TMA unit can consume. TMA tensors have basis stride elements so that the
       associated layout can output coordinates; otherwise they can be partitioned like any
       other CuTe tensor using the algebra.

    :param op: The Copy Operation to construct an Atom for
    :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp]
    :param gmem_tensor: The GMEM tensor to be loaded by this copy atom
    :type gmem_tensor: Tensor
    :param smem_layout: Shared memory layout to load the tensor into (PDSL)
    :type smem_layout: Union[Layout, core.ComposedLayout]
    :param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions
    :type mma_tiler_mnk: Shape
    :param tiled_mma: The TiledMMA that will consume the load as operands
    :type tiled_mma: core.TiledMma
    :param cluster_shape_vmnk: The Cluster-level shape in VMNK dimensions
    :type cluster_shape_vmnk: Shape
    :param internal_type: Optional internal data type to use when the element type does not
        match the copy type
    :type internal_type: Type[Numeric]
    :return: A copy atom for this operation and the associated TMA coord tensor
    :rtype: Tuple[core.CopyAtom, Tensor]
    :raises TypeError: if ``internal_type`` is provided but is not a Numeric type
    """

    if internal_type is not None:
        if not isinstance(internal_type, NumericMeta):
            raise TypeError(f"internal_type must be a Numeric, but got {internal_type}")
        internal_type = internal_type.mlir_type
    check_type_in(
        op,
        [CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp],
        "op",
        "make_tiled_tma_atom_A",
    )

    # Project the MNK tiler onto (M, K...) and tile the GMEM coordinate layout with it
    identity_layout = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip)
    mma_tiler_mk = (mma_tiler_mnk[0], *mma_tiler_mnk[2:])
    g_tile = core.composition(identity_layout, mma_tiler_mk, loc=loc, ip=ip)

    # Partition the tile for operand A, keep mode 1, then dice it
    thrfrg_a = tiled_mma._thrfrg_A(g_tile)
    mode1_map = core.get(thrfrg_a, mode=[1])
    cta_v_map = core.dice(mode1_map, (1, (1,) * core.rank(g_tile)))

    # Select the multicast factor and the trait class together, once
    if isinstance(op, CopyBulkTensorTileG2SOp):
        num_multicast = 1
        trait_cls = CopyBulkTensorTileG2SNonExecTrait
    else:
        assert isinstance(op, CopyBulkTensorTileG2SMulticastOp)
        # multicast across the N-mode since those would share the same tile of A
        num_multicast = core.size(cluster_shape_vmnk, mode=[2])
        trait_cls = CopyBulkTensorTileG2SMulticastNonExecTrait

    # tma_pair[0] = the IR Value for the non-executable atom instance
    # tma_pair[1] = the IR Value for the associated TMA tensor
    tma_pair = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load(
        gmem_tensor.value,
        smem_layout,
        cta_v_map,
        op._to_ir(),
        num_multicast=num_multicast,
        internal_type=internal_type,
        loc=loc,
        ip=ip,
    )
    return core.CopyAtom(op, trait_cls(tma_pair[0])), tma_pair[1]
+ + :param op: The Copy Operation to construct an Atom for + :type op: Union[CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp] + :param gmem_tensor: The GMEM tensor to be loaded by this copy atom + :type gmem_tensor: Tensor + :param smem_layout: Shared memory layout to load the tensor into (PDSL) + :type smem_layout: Union[Layout, core.ComposedLayout] + :param mma_tiler_mnk: The MMA Tiler shape (TILE_M, TILE_N, TILE_K) in MNK dimensions + :type mma_tiler_mnk: Shape + :param tiled_mma: The TiledMMA that will consume the load as operands + :type tiled_mma: core.TiledMma + :param cluster_shape_vmnk: The Cluster-level shape in VMNK dimensions + :type cluster_shape_vmnk: Shape + :param internal_type: An optional parameter for the internal data type to when element + type does not match the copy type + :type internal_type: Type[Numeric] + :return: A Copy Atom for this Operation and the associated TMA tensor + :rtype: Tuple[core.CopyAtom, Tensor] + + """ + + if internal_type is not None: + if not isinstance(internal_type, NumericMeta): + raise TypeError(f"internal_type must be a Numeric, but got {internal_type}") + internal_type = internal_type.mlir_type + check_type_in( + op, + [CopyBulkTensorTileG2SOp, CopyBulkTensorTileG2SMulticastOp], + "op", + "make_tiled_tma_atom_B", + ) + + ident = core.make_identity_layout(gmem_tensor.shape, loc=loc, ip=ip) + mma_tiler_nk = (mma_tiler_mnk[1], *mma_tiler_mnk[2:]) + g_tile = core.composition(ident, mma_tiler_nk, loc=loc, ip=ip) + cta_v_map = tiled_mma._thrfrg_B(g_tile) + cta_v_map = core.get(cta_v_map, mode=[1]) + cta_v_map = core.dice(cta_v_map, (1, (1,) * core.rank(g_tile))) + + if isinstance(op, CopyBulkTensorTileG2SOp): + num_multicast = 1 + else: + assert isinstance(op, CopyBulkTensorTileG2SMulticastOp) + # multicast across the M-mode since those would share the same tile of B + num_multicast = core.size(cluster_shape_vmnk, mode=[1]) + + # res[0] = the IR Value for the non-executable atom instance + # res[1] = the IR 
Value for the associated TMA tensor + res = _cute_nvgpu_ir.atom_make_non_exec_tiled_tma_load( + gmem_tensor.value, + smem_layout, + cta_v_map, + op._to_ir(), + num_multicast=num_multicast, + internal_type=internal_type, + loc=loc, + ip=ip, + ) + if isinstance(op, CopyBulkTensorTileG2SOp): + return core.CopyAtom(op, CopyBulkTensorTileG2SNonExecTrait(res[0])), res[1] + else: + assert isinstance(op, CopyBulkTensorTileG2SMulticastOp) + return ( + core.CopyAtom(op, CopyBulkTensorTileG2SMulticastNonExecTrait(res[0])), + res[1], + ) + + +__all__ = [ + "make_tiled_tma_atom_A", + "make_tiled_tma_atom_B", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2831bec6039b86a2231a5f05bdd3d1b9e0d891b0 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/__init__.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ +from .copy import * +from .mma import * +from .helpers import * + +# __all__ is required here for documentation generation +__all__ = [ + # + # copy.py + # + "Repetition", + "Pack", + "Unpack", + "Ld16x64bOp", + "Ld16x128bOp", + "Ld16x256bOp", + "Ld16x32bx2Op", + "Ld32x32bOp", + "St16x64bOp", + "St16x128bOp", + "St16x256bOp", + "St16x32bx2Op", + "St32x32bOp", + # + # mma.py + # + "OperandMajorMode", + "OperandSource", + "CtaGroup", + "Field", + "MmaTF32Op", + "MmaF16BF16Op", + "MmaI8Op", + "MmaFP8Op", + "MmaMXF8Op", + "MmaMXF4Op", + "MmaMXF4NVF4Op", + "SmemLayoutAtomKind", + # + # helpers.py + # + "make_smem_layout_atom", + "tile_to_mma_shape", + "commit", + "is_tmem_load", + "is_tmem_store", + "get_tmem_copy_properties", + "find_tmem_tensor_col_offset", + "make_tmem_copy", + "make_s2t_copy", + "get_s2t_smem_desc_tensor", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py new file mode 100644 index 0000000000000000000000000000000000000000..df954b09d5bcd30321df0dd65a9955fd30a0e811 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py @@ -0,0 +1,663 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
class Repetition(enum.Enum):
    """
    An enumeration for the number of repetitions of a given TMEM copy within the instruction.
    """

    x1 = 1
    x2 = 2
    x4 = 4
    x8 = 8
    x16 = 16
    x32 = 32
    x64 = 64
    x128 = 128

    def __str__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}.{self.name}>"

    @classmethod
    def _missing_(cls, value):
        """
        Fallback lookup used by ``enum`` when ``value`` is not found directly.

        Returns the member whose value equals the given integer, or ``None`` (which makes
        ``enum`` raise ``ValueError``) when nothing matches or ``value`` is not an ``int``.

        The previous hand-written ``if``/``elif`` chain skipped ``4``; iterating over the
        members keeps this lookup in sync with the member list.
        """
        if isinstance(value, int):
            for member in cls:
                if member.value == value:
                    return member
        return None
@dataclass(frozen=True)
class Ld16x64bOp(_LdBase):
    """
    16x64b TMEM load Operation.

    This Operation corresponds to the ``.16x64b`` qualifier of the tcgen05 TMEM load
    instruction; see the PTX documentation.
    """

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "Ld16x64bTrait":
        """
        Create the Trait wrapping the Copy Atom for this load shape.

        :param copy_internal_type: Numeric type whose ``mlir_type`` parametrizes the atom type
        :return: The Trait holding the newly created atom value
        """
        element_type = copy_internal_type.mlir_type
        num_repeats = self.repeat.value
        # The packing flag is a unit attribute: present when packing is requested, else absent
        if self.pack == Pack.PACK_16b_IN_32b:
            pack_attr = ir.UnitAttr.get()
        else:
            pack_attr = None
        atom_type = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
            element_type, 16, 64, num_repeats, pack_attr
        )
        return Ld16x64bTrait(_cute_ir.atom(atom_type, loc=loc, ip=ip))
@dataclass(frozen=True)
class Ld16x256bOp(_LdBase):
    """
    16x256b TMEM load Operation.

    This Operation corresponds to the ``.16x256b`` qualifier of the tcgen05 TMEM load
    instruction; see the PTX documentation. The x64 and x128 repetitions are rejected.
    """

    def __post_init__(self) -> None:
        """
        Run the base-class validation, then reject the unsupported repetition counts.

        :raises OpError: if ``repeat`` is ``Repetition.x128`` or ``Repetition.x64``
        """
        super().__post_init__()
        unsupported = (Repetition.x128, Repetition.x64)
        if self.repeat in unsupported:
            raise OpError(
                self,
                "x64 and x128 repetition is not supported",
                suggestion="choose one of x1, x2, x4, x8, x16, x32",
            )

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "Ld16x256bTrait":
        """
        Create the Trait wrapping the Copy Atom for this load shape.

        :param copy_internal_type: Numeric type whose ``mlir_type`` parametrizes the atom type
        :return: The Trait holding the newly created atom value
        """
        element_type = copy_internal_type.mlir_type
        num_repeats = self.repeat.value
        # The packing flag is a unit attribute: present when packing is requested, else absent
        if self.pack == Pack.PACK_16b_IN_32b:
            pack_attr = ir.UnitAttr.get()
        else:
            pack_attr = None
        atom_type = _cute_nvgpu_ir.CopyAtomSM100TmemLoadType.get(
            element_type, 16, 256, num_repeats, pack_attr
        )
        return Ld16x256bTrait(_cute_ir.atom(atom_type, loc=loc, ip=ip))
@dataclass(frozen=True)
class _StBase(CopyOp):
    """
    Common base class of the tcgen05 TMEM store Copy Operations.

    On construction it checks the target architecture and the types of the Op parameters.
    """

    # Number of repetitions of the copy within the instruction (required, no default)
    repeat: Repetition
    # Unpacking pattern applied by the store
    unpack: Unpack = Unpack.NONE

    admissible_archs = [
        "sm_100a",
        "sm_100f",
    ]

    def __post_init__(self) -> None:
        """
        Validate the architecture and the Op parameters.

        :raises OpError: if the architecture reported by the DSL is not listed in
            ``admissible_archs``, if ``repeat`` is not a ``Repetition``, or if ``unpack`` is
            not an ``Unpack``
        """
        # Arch verification
        arch = CuTeDSL._get_dsl().envar.arch
        if arch not in self.admissible_archs:
            raise OpError(
                self,
                f"expects arch to be one of {self.admissible_archs}, but got {arch}",
                suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture",
            )

        if not isinstance(self.repeat, Repetition):
            raise OpError(
                self,
                "expects the 'repeat' Op parameter to be a tcgen05.Repetition instance",
            )
        if not isinstance(self.unpack, Unpack):
            # The message previously named the parameter 'pack', which this class does not have
            raise OpError(
                self,
                "expects the 'unpack' Op parameter to be a tcgen05.Unpack instance",
            )

    def __str__(self) -> str:
        """
        Return a multi-line description: Operation name, repetition count, optional unpacking.
        """
        # Strip the trailing "Op" from the class name
        res = (
            f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation"
            + f"\n number of repetitions = {self.repeat.value}"
        )
        if self.unpack == Unpack.UNPACK_32b_IN_16b:
            res += "\n with 32-bit to 2x 16b unpacking"
        return res
@dataclass(frozen=True)
class St16x32bx2Op(_StBase):
    """
    16x32bx2 TMEM store Operation.

    See the PTX documentation.
    This Operation corresponds to the ``.16x32bx2`` qualifier.
    """

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "St16x32bx2Trait":
        # Build the store atom type from the element type, the 16/32 shape parameters, the
        # repetition count and the optional unpack flag, then wrap the atom value in the Trait
        ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get(
            copy_internal_type.mlir_type,
            16,
            32,
            self.repeat.value,
            ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None,
        )
        return St16x32bx2Trait(_cute_ir.atom(ty, loc=loc, ip=ip))
+ """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "St32x32bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100TmemStoreType.get( + copy_internal_type.mlir_type, + 32, + 32, + self.repeat.value, + ir.UnitAttr.get() if self.unpack == Unpack.UNPACK_32b_IN_16b else None, + ) + return St32x32bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class St32x32bTrait(Trait): + pass + + +@dataclass(frozen=True) +class _S2TCopyBase(CopyOp): + cta_group: CtaGroup + + admissible_archs = [ + "sm_100a", + "sm_100f", + ] + + def __post_init__(self) -> None: + # Arch verification + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + # Verify that the user provided enum values + if not isinstance(self.cta_group, CtaGroup): + raise OpError( + self, + "expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance", + ) + + def __str__(self) -> str: + res = ( + f"tcgen05 {self.__class__.__name__[:-2]} Copy Operation" + + f"\n CTA group = {self.cta_group}" + ) + + return res + + +@dataclass(frozen=True) +class Cp128x256bOp(_S2TCopyBase): + """ + 128x256b SMEM to TMEM Copy Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.128x256b`` qualifier. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Cp128x256bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( + copy_internal_type.mlir_type, + 128, + 256, + self.cta_group.value, + _cute_nvgpu_ir.CopyS2TBroadcast.none, + ) + return Cp128x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Cp128x256bTrait(Trait): + pass + + +@dataclass(frozen=True) +class Cp128x128bOp(_S2TCopyBase): + """ + 128x128b SMEM to TMEM Copy Operation. + + See the `PTX documentation `__. 
+ This Operation corresponds to the ``.128x128b`` qualifier. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Cp128x128bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( + copy_internal_type.mlir_type, + 128, + 128, + self.cta_group.value, + _cute_nvgpu_ir.CopyS2TBroadcast.none, + ) + return Cp128x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Cp128x128bTrait(Trait): + pass + + +@dataclass(frozen=True) +class Cp4x256bOp(_S2TCopyBase): + """ + 4x256b SMEM to TMEM Copy Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.4x256b`` qualifier. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Cp4x256bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( + copy_internal_type.mlir_type, + 4, + 256, + self.cta_group.value, + _cute_nvgpu_ir.CopyS2TBroadcast.none, + ) + return Cp4x256bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Cp4x256bTrait(Trait): + pass + + +@dataclass(frozen=True) +class Cp4x32x128bOp(_S2TCopyBase): + """ + 32x128b SMEM to TMEM Copy Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.32x128b`` qualifier with ``warpx4`` broadcast qualifier enabled. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Cp4x32x128bTrait": + ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( + copy_internal_type.mlir_type, + 32, + 128, + self.cta_group.value, + _cute_nvgpu_ir.CopyS2TBroadcast.x4, + ) + return Cp4x32x128bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Cp4x32x128bTrait(Trait): + pass + + +@dataclass(frozen=True) +class Cp2x64x128b0213Op(_S2TCopyBase): + """ + 64x128b SMEM to TMEM Copy Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.64x128b`` qualifier with ``.warpx2::02_13`` broadcast qualifier enabled. 
+ """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Cp2x64x128b0213Trait": + ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( + copy_internal_type.mlir_type, + 64, + 128, + self.cta_group.value, + _cute_nvgpu_ir.CopyS2TBroadcast.lw_0213, + ) + return Cp2x64x128b0213Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Cp2x64x128b0213Trait(Trait): + pass + + +@dataclass(frozen=True) +class Cp2x64x128b0123Op(_S2TCopyBase): + """ + 64x128b SMEM to TMEM Copy Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.64x128b`` qualifier with ``.warpx2::01_23`` broadcast qualifier enabled. + """ + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "Cp2x64x128b0123Trait": + ty = _cute_nvgpu_ir.CopyAtomSM100CopyS2TType.get( + copy_internal_type.mlir_type, + 64, + 128, + self.cta_group.value, + _cute_nvgpu_ir.CopyS2TBroadcast.lw_0123, + ) + return Cp2x64x128b0123Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class Cp2x64x128b0123Trait(Trait): + pass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..0ad27e62962e874da6707ac8a36863d5ed8f98a4 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/helpers.py @@ -0,0 +1,328 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import overload, Type, Tuple, Union + +from cutlass.cutlass_dsl import dsl_user_op + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir.dialects import nvvm + +from ...typing import ( + Shape, + IntTuple, + Layout, + Tensor, + Int, + Numeric, + NumericMeta, + Int16, + Int32, +) +from ... import core +from .mma import SmemLayoutAtomKind, CtaGroup +from .copy import ( + Pack, + Unpack, + Ld16x64bOp, + Ld16x128bOp, + Ld16x256bOp, + Ld16x32bx2Op, + Ld32x32bOp, + St16x64bOp, + St16x128bOp, + St16x256bOp, + St16x32bx2Op, + St32x32bOp, +) + + +#################################################################################################### +# +# Helper functions for MMA +# +#################################################################################################### + + +@dsl_user_op +def make_smem_layout_atom( + kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None +) -> core.ComposedLayout: + """ + Makes a SMEM layout Atom. + + This function creates a composed layout in unit of elements consistent with the requested layout + Atom kind and element data type. 
+ + :param kind: The kind of layout Atom + :type kind: SmemLayoutAtomKind + :param element_type: The element data type to construct the layout for + :type element_type: Type[Numeric] + :return: The SMEM layout atom + :rtype: core.ComposedLayout + """ + if not isinstance(element_type, NumericMeta): + raise TypeError(f"element_type must be a Numeric, but got {element_type}") + + if kind in (SmemLayoutAtomKind.MN_INTER, SmemLayoutAtomKind.K_INTER): + num_contiguous_bits = 128 + sw = core.make_swizzle(0, 4, 3) + elif kind in (SmemLayoutAtomKind.MN_SW32, SmemLayoutAtomKind.K_SW32): + num_contiguous_bits = 256 + sw = core.make_swizzle(1, 4, 3) + elif kind in (SmemLayoutAtomKind.MN_SW64, SmemLayoutAtomKind.K_SW64): + num_contiguous_bits = 512 + sw = core.make_swizzle(2, 4, 3) + elif kind in (SmemLayoutAtomKind.MN_SW128, SmemLayoutAtomKind.K_SW128): + num_contiguous_bits = 1024 + sw = core.make_swizzle(3, 4, 3) + elif kind == SmemLayoutAtomKind.MN_SW128_32B: + num_contiguous_bits = 1024 + sw = core.make_swizzle(2, 5, 2) + else: + raise ValueError("unrecognized SMEM layout atom kind") + num_contiguous_elems = num_contiguous_bits // element_type.width + + if kind in ( + SmemLayoutAtomKind.MN_INTER, + SmemLayoutAtomKind.MN_SW32, + SmemLayoutAtomKind.MN_SW64, + SmemLayoutAtomKind.MN_SW128, + SmemLayoutAtomKind.MN_SW128_32B, + ): + # M/N-major layout + return core.make_composed_layout( + sw, + 0, + core.make_layout( + (num_contiguous_elems, 8), stride=(1, num_contiguous_elems) + ), + loc=loc, + ip=ip, + ) + else: + # K-major layout + return core.make_composed_layout( + sw, + 0, + core.make_layout( + (8, num_contiguous_elems), stride=(num_contiguous_elems, 1) + ), + loc=loc, + ip=ip, + ) + + +@overload +def tile_to_mma_shape( + atom: Layout, mma_tile_shape: Shape, order: IntTuple = None, *, loc=None, ip=None +) -> Layout: ... 
+ + +@overload +def tile_to_mma_shape( + atom: core.ComposedLayout, + mma_tile_shape: Shape, + order: IntTuple = None, + *, + loc=None, + ip=None, +) -> core.ComposedLayout: ... + + +@dsl_user_op +def tile_to_mma_shape( + atom, mma_tile_shape: Shape, order: IntTuple = None, *, loc=None, ip=None +): + """ + Tiles a layout to an MMA shape. + """ + # Default order is colexicographical + if order is None: + order = tuple(range(core.rank(mma_tile_shape) - 1)) + if core.rank(order) != core.rank(mma_tile_shape) - 1: + raise ValueError( + f"rank(order)={core.rank(order)} must be equal to " + f"rank(mma_tile_shape)-1={core.rank(mma_tile_shape)-1}" + ) + order_val = core._pack_int_tuple(order, loc=loc, ip=ip) + mma_tile_shape_val = core._pack_shape(mma_tile_shape, loc=loc, ip=ip) + + if not ( + core.is_static(atom) + and core.is_static(mma_tile_shape_val) + and core.is_static(order_val) + ): + raise ValueError("tile_to_mma_shape only supports static inputs") + + res_ty = _cute_nvgpu_ir.tile_to_mma_shape(atom, mma_tile_shape_val, order_val) + return _cute_ir.static(res_ty, loc=loc, ip=ip) + + +@dsl_user_op +def commit( + mbar_ptr: core.Pointer, + mask=None, + cta_group: CtaGroup = CtaGroup.ONE, + *, + loc=None, + ip=None, +) -> None: + """ + Perform an arrive operation on a mbarrier upon completion of previous MMA operations. 
+ + :param mbar_ptr: A pointer to the mbarrier in SMEM + :type mbar_ptr: Pointer + :param mask: An optional multicast mask for the CTAs in the cluster to signal arrival to + :type mask: Int + """ + if cta_group == CtaGroup.ONE: + group = nvvm.Tcgen05GroupKind.CTA_1 + else: + assert cta_group == CtaGroup.TWO + group = nvvm.Tcgen05GroupKind.CTA_2 + + mbar_ptr = mbar_ptr.llvm_ptr + if mask is not None: + mask = Int16(mask).ir_value(loc=loc, ip=ip) + nvvm.tcgen05_commit_arrive( + mbar_ptr, multicast_mask=mask, group=group, loc=loc, ip=ip + ) + else: + nvvm.tcgen05_commit_arrive(mbar_ptr, group=group, loc=loc, ip=ip) + return + + +#################################################################################################### +# +# Helper functions for Copies +# +#################################################################################################### + + +def is_tmem_load(atom: core.CopyAtom) -> bool: + """ + Returns whether a CopyAtom instance is a TMEM load. + """ + return isinstance( + atom.op, + ( + Ld16x64bOp, + Ld16x128bOp, + Ld16x256bOp, + Ld16x32bx2Op, + Ld32x32bOp, + ), + ) + + +def is_tmem_store(atom: core.CopyAtom) -> bool: + """ + Returns whether a CopyAtom instance is a TMEM store. + """ + return isinstance( + atom.op, + ( + St16x64bOp, + St16x128bOp, + St16x256bOp, + St16x32bx2Op, + St32x32bOp, + ), + ) + + +def get_tmem_copy_properties( + atom: core.CopyAtom, +) -> Tuple[int, int, int, Union[Pack, Unpack]]: + """ + Returns the properties of a TMEM copy atom (number of data paths, bits, repetitions, + and whether packing/unpacking is used). 
+ """ + if isinstance(atom.op, (Ld16x64bOp, St16x64bOp)): + num_dp, num_bits = 16, 64 + elif isinstance(atom.op, (Ld16x128bOp, St16x128bOp)): + num_dp, num_bits = 16, 128 + elif isinstance(atom.op, (Ld16x256bOp, St16x256bOp)): + num_dp, num_bits = 16, 256 + elif isinstance(atom.op, (Ld16x32bx2Op, St16x32bx2Op)): + num_dp, num_bits = 16, 32 + elif isinstance(atom.op, (Ld32x32bOp, St32x32bOp)): + num_dp, num_bits = 32, 32 + else: + raise ValueError(f"expects 'atom' to be a TMEM copy, but got {atom}") + if is_tmem_load(atom): + return num_dp, num_bits, atom.op.repeat.value, atom.op.pack + else: + assert is_tmem_store(atom), "atom must be a TMEM store" + return num_dp, num_bits, atom.op.repeat.value, atom.op.unpack + + +@dsl_user_op +def find_tmem_tensor_col_offset(tmem_tensor: Tensor, *, loc=None, ip=None) -> Int: + """ + Computes the TMEM column offset given a TMEM tensor. + + :param tmem_tensor: The TMEM tensor to use to compute the columns offset + :type tmem_tensor: Tensor + :return: The columns offset + :rtype: Int + """ + tmem_col_mask = 0x0000FFFF + offset = ( + core.cosize(core.recast_tensor(tmem_tensor, Int32).layout, loc=loc, ip=ip) + & tmem_col_mask + ) + if isinstance(offset, int): + return offset + return Int32(offset, loc=loc, ip=ip) + + +@dsl_user_op +def make_tmem_copy( + atom: core.CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None +) -> core.TiledCopy: + """ + Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor. + """ + tiled_copy_val = _cute_nvgpu_ir.atom_make_tmem_copy( + atom._trait.value, tmem_tensor.value, loc=loc, ip=ip + ) + new_trait = type(atom._trait)(tiled_copy_val) + return core.TiledCopy(atom.op, new_trait) + + +@dsl_user_op +def make_s2t_copy( + atom: core.CopyAtom, tmem_tensor: Tensor, *, loc=None, ip=None +) -> core.TiledCopy: + """ + Makes a Tiled Copy instance from a TMEM Copy Atom and a TMEM tensor. 
+ """ + tiled_copy_val = _cute_nvgpu_ir.atom_make_s2t_copy( + atom._trait.value, tmem_tensor.value, loc=loc, ip=ip + ) + new_trait = type(atom._trait)(tiled_copy_val) + return core.TiledCopy(atom.op, new_trait) + + +@dsl_user_op +def get_s2t_smem_desc_tensor( + atom: core.CopyAtom, smem_tensor: Tensor, *, loc=None, ip=None +) -> Tensor: + """ + Returns the SMEM descriptor tensor from a S2T copy atom and a SMEM tensor. + """ + smem_desc_tensor = _cute_nvgpu_ir.atom_get_copy_s2t_smem_desc_view( + atom._trait.value, smem_tensor.value, loc=loc, ip=ip + ) + return smem_desc_tensor diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py new file mode 100644 index 0000000000000000000000000000000000000000..3a938523e130cf551c205669164e15e8bbd29132 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py @@ -0,0 +1,1041 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import enum +from dataclasses import dataclass +from typing import Type + +from cutlass.cutlass_dsl import CuTeDSL, T + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir import ir + +from ..common import OpError +from ... 
import core +from ...core import Trait, _pack_shape, rank, depth, _Tensor +from ...typing import ( + Shape, + Float4E2M1FN, + Float8E8M0FNU, + Float8E5M2, + Float8E4M3FN, + Float16, + BFloat16, + Float32, + TFloat32, + Boolean, + Int8, + Uint8, + Int32, + Numeric, + AddressSpace, + Pointer, +) + + +#################################################################################################### +# +# MMA Ops and Traits +# +#################################################################################################### + + +class OperandMajorMode(enum.Enum): + """ + An enumeration for the majorness of the input operands of the MMA. + """ + + MN = _cute_ir.MajorMode.mn + K = _cute_ir.MajorMode.k + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + @classmethod + def _missing_(cls, value): + if isinstance(value, str): + value = value.upper() + if value == "MN": + return OperandMajorMode.MN + elif value == "K": + return OperandMajorMode.K + + def _to_ir(self) -> _cute_ir.MajorMode: + return self.value + + +class OperandSource(enum.Enum): + """ + An enumeration for the source memory location of the A input operand of the MMA. + """ + + TMEM = _cute_ir.MmaFragKind.tmem + SMEM = _cute_ir.MmaFragKind.smem_desc + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir(self) -> _cute_ir.MmaFragKind: + return self.value + + +class CtaGroup(enum.Enum): + """ + An enumeration for the ``cta_group`` qualifier of the MMA. + """ + + ONE = 1 + TWO = 2 + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + +class Field(enum.Enum): + """ + An enumeration for the fields of the MMA Atom that can be modified at runtime. 
+ """ + + NEGATE_A = "neg_a" + NEGATE_B = "neg_b" + ACCUMULATE = "accum_c" + SFA = "sf_a" + SFB = "sf_b" + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir_field_name(self) -> str: + return self.value + + +# Base class for all tcgen05 MMA Ops with syntax `tcgen05.mma.cta_group.kind` used to factor out some internal code +@dataclass(frozen=True) +class MmaOp(core.MmaOp): + a_dtype: Type[Numeric] + b_dtype: Type[Numeric] + acc_dtype: Type[Numeric] + shape_mnk: Shape + cta_group: CtaGroup + a_src: OperandSource + a_major_mode: OperandMajorMode + b_major_mode: OperandMajorMode + + admissible_archs = [ + "sm_100a", + "sm_100f", + ] + + def __post_init__(self) -> None: + # Verify arch + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + # Verify that the user provided enum values + if not isinstance(self.cta_group, CtaGroup): + raise OpError( + self, + "expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance", + ) + if not isinstance(self.a_src, OperandSource): + raise OpError( + self, + "expects the 'a_src' Op parameter to be a tcgen05.OperandSource instance", + ) + if not isinstance(self.a_major_mode, OperandMajorMode): + raise OpError( + self, + "expects the 'a_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", + ) + if not isinstance(self.b_major_mode, OperandMajorMode): + raise OpError( + self, + "expects the 'b_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", + ) + # Verify the instruction shape + if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1): + raise OpError( + self, + f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, " + f"but got 
{self.shape_mnk}", + ) + m, n = self.shape_mnk[0], self.shape_mnk[1] + if self.cta_group == CtaGroup.ONE: + if m not in [64, 128]: + raise OpError(self, f"expects the M-mode to be 64 or 128, but got {m}") + if m == 64: + if (n < 8) or (n > 256) or (n % 8 != 0): + raise OpError( + self, + f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}", + ) + elif m == 128: + if (n < 16) or (n > 256) or (n % 16 != 0): + raise OpError( + self, + f"expects the N-mode to satisfy 8 <= N <= 256 and N % 16 == 0, but got {n}", + ) + else: + if m not in [128, 256]: + raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}") + if (n < 32) or (n > 256) or (n % 32 != 0): + raise OpError( + self, + f"expects the N-mode to satisfy 32 <= N <= 256 and N % 32 == 0, but got {n}", + ) + + def __str__(self) -> str: + return ( + self.__class__.descriptive_name # type: ignore + + f"\n A data type = {self.a_dtype}" + + f"\n B data type = {self.b_dtype}" + + f"\n Accumulator data type = {self.acc_dtype}" + + f"\n CTA group = {self.cta_group}" + + f"\n A source location = {self.a_src}" + + f"\n A major mode = {self.a_major_mode}" + + f"\n B major mode = {self.b_major_mode}" + + f"\n Instruction shape MNK = {self.shape_mnk}" + ) + + def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand A, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + + def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand B, " + f"but got composed 
layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + + +class MmaTrait(Trait): + admissible_fields = [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B] + + def set(self, field, value, *, loc=None, ip=None) -> None: + if field not in self.admissible_fields: + raise ValueError( + f"expects field to be one of {self.admissible_fields}, but got {field}" + ) + field_name = f"#cute_nvgpu.atom_mma_field_sm100<{field._to_ir_field_name()}>" + attr = ir.Attribute.parse(field_name) + self.value = _cute_nvgpu_ir.atom_set_value( + self.value, attr, Boolean(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + + +# Base class for all tcgen05 BlockScaled MMA Ops with syntax `tcgen05.mma.cta_group.kind.block_scale` used to factor out some internal code +@dataclass(frozen=True) +class BlockScaledMmaOp(core.MmaOp): + a_dtype: Type[Numeric] + b_dtype: Type[Numeric] + acc_dtype: Float32 + sf_dtype: Type[Numeric] + sf_vec_size: int + shape_mnk: Shape + cta_group: CtaGroup + a_src: OperandSource + a_major_mode: OperandMajorMode + b_major_mode: OperandMajorMode + + admissible_archs = [ + "sm_100a", + ] + + def __post_init__(self) -> None: + # Verify arch + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + # Verify that the user provided enum values + if not isinstance(self.cta_group, CtaGroup): + raise OpError( + self, + "expects the 'cta_group' Op parameter to be a tcgen05.CtaGroup instance", + ) + if not isinstance(self.a_src, OperandSource): + raise OpError( + self, + "expects the 'a_src' Op parameter to be a tcgen05.OperandSource instance", + ) + if not isinstance(self.a_major_mode, OperandMajorMode): + raise OpError( + self, + "expects the 'a_major_mode' Op parameter to be a 
tcgen05.OperandMajorMode instance", + ) + if not isinstance(self.b_major_mode, OperandMajorMode): + raise OpError( + self, + "expects the 'b_major_mode' Op parameter to be a tcgen05.OperandMajorMode instance", + ) + # Verify the instruction shape + if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1): + raise OpError( + self, + f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, " + f"but got {self.shape_mnk}", + ) + m, n = self.shape_mnk[0], self.shape_mnk[1] + if self.cta_group == CtaGroup.ONE: + if m != 128: + raise OpError(self, f"expects the M-mode to be 128, but got {m}") + + if (n < 8) or (n > 256) or (n % 8 != 0): + raise OpError( + self, + f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got {n}", + ) + else: + if m not in [128, 256]: + raise OpError(self, f"expects the M-mode to be 128 or 256, but got {m}") + if (n < 16) or (n > 256) or (n % 16 != 0): + raise OpError( + self, + f"expects the N-mode to satisfy 16 <= N <= 256 and N % 16 == 0, but got {n}", + ) + if self.sf_vec_size not in [16, 32]: + raise OpError( + self, + f"expects the scale factor vector size to be 16 or 32, but got {self.sf_vec_size}", + ) + + def __str__(self) -> str: + return ( + self.__class__.descriptive_name # type: ignore + + f"\n A data type = {self.a_dtype}" + + f"\n B data type = {self.b_dtype}" + + f"\n Accumulator data type = {self.acc_dtype}" + + f"\n Scale factor data type = {self.sf_dtype}" + + f"\n Scale factor vector size = {self.sf_vec_size}" + + f"\n CTA group = {self.cta_group}" + + f"\n A source location = {self.a_src}" + + f"\n A major mode = {self.a_major_mode}" + + f"\n B major mode = {self.b_major_mode}" + + f"\n Instruction shape MNK = {self.shape_mnk}" + ) + + def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for 
{self._make_trait()}'s operand A, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + + def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand B, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + + +class BlockScaledMmaTraits(Trait): + admissible_fields = [ + Field.ACCUMULATE, + Field.NEGATE_A, + Field.NEGATE_B, + Field.SFA, + Field.SFB, + ] + + def set(self, field, value, *, loc=None, ip=None) -> None: + if field not in self.admissible_fields: + raise ValueError( + f"expects field to be one of {self.admissible_fields}, but got {field}" + ) + if field in [Field.ACCUMULATE, Field.NEGATE_A, Field.NEGATE_B]: + value = Boolean(value).ir_value(loc=loc, ip=ip) + elif field in [Field.SFA, Field.SFB]: + if not isinstance(value, Pointer): + raise ValueError( + f"expects value to be a pointer for {field}, but got {type(value).__name__}" + ) + value = value.value + + field_name = f"#cute_nvgpu.atom_mma_field_sm100_block_scaled<{field._to_ir_field_name()}>" + attr = ir.Attribute.parse(field_name) + self.value = _cute_nvgpu_ir.atom_set_value( + self.value, attr, value, loc=loc, ip=ip + ) + + +# +# TF32 MMA +# + + +@dataclass(frozen=True) +class MmaTF32Op(MmaOp): + """ + TF32 tcgen05 MMA Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.kind::tf32`` qualifier. 
+ """ + + descriptive_name = "tcgen05 TF32 MMA Operation" + + def __init__( + self, + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + a_major_mode: OperandMajorMode, + b_major_mode: OperandMajorMode, + ) -> None: + super().__init__( + TFloat32, + TFloat32, + Float32, + instruction_shape, + cta_group, + a_src, + a_major_mode, + b_major_mode, + ) + self._verify() + + def _verify(self) -> None: + # Verify the instruction shape + instruction_k = 8 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaTF32Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.a_src._to_ir(), + 0, + ) + return MmaTF32Trait( + _cute_nvgpu_ir.make_sm100_mma( + ty, + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + + +class MmaTF32Trait(MmaTrait): + pass + + +# +# F16/BF16 MMA +# + + +@dataclass(frozen=True) +class MmaF16BF16Op(MmaOp): + """ + F16/BF16 tcgen05 MMA Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.kind::f16`` qualifier. 
+ """ + + descriptive_name = "tcgen05 F16/BF16 MMA Operation" + + def __init__( + self, + ab_dtype: Type[Numeric], + acc_dtype: Type[Numeric], + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + a_major_mode: OperandMajorMode, + b_major_mode: OperandMajorMode, + ) -> None: + super().__init__( + ab_dtype, + ab_dtype, + acc_dtype, + instruction_shape, + cta_group, + a_src, + a_major_mode, + b_major_mode, + ) + self._verify() + + def _verify(self) -> None: + # Input data type verification + if self.a_dtype not in [Float16, BFloat16]: + raise OpError( + self, + "expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16", + ) + assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" + # Accumulator data type verification + if self.acc_dtype not in [Float16, Float32]: + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", + ) + # Instruction shape verification + instruction_k = 16 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.a_src._to_ir(), + 0, + ) + return MmaF16BF16Trait( + _cute_nvgpu_ir.make_sm100_mma( + ty, + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + + +class MmaF16BF16Trait(MmaTrait): + pass + + +# +# I8 MMA +# + + +@dataclass(frozen=True) 
+class MmaI8Op(MmaOp): + """ + I8 tcgen05 MMA Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.kind::i8`` qualifier. + """ + + descriptive_name = "tcgen05 I8 MMA Operation" + + def __init__( + self, + ab_dtype: Type[Numeric], + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + a_major_mode: OperandMajorMode, + b_major_mode: OperandMajorMode, + ) -> None: + super().__init__( + ab_dtype, + ab_dtype, + Int32, + instruction_shape, + cta_group, + a_src, + a_major_mode, + b_major_mode, + ) + self._verify() + + def _verify(self) -> None: + # Input data type verification + if self.a_dtype not in [Int8, Uint8]: + raise OpError( + self, + "expects the 'ab_dtype' Op parameter to be one of Int8 or Uint8", + ) + assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" + # Instruction shape verification + instruction_k = 32 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaI8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + (T.si8() if self.a_dtype.signed else T.ui8()), + (T.si8() if self.b_dtype.signed else T.ui8()), + T.si32(), + self.a_src._to_ir(), + 0, + ) + return MmaI8Trait( + _cute_nvgpu_ir.make_sm100_mma( + ty, + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + + +class MmaI8Trait(MmaTrait): + pass + + +# +# F8F6F4 MMA +# + + +@dataclass(frozen=True) +class MmaFP8Op(MmaOp): + """ + F8 tcgen05 MMA Operation. 
+ + See the `PTX documentation `__. + """ + + descriptive_name = "tcgen05 F8 MMA Operation" + + def __init__( + self, + ab_dtype: Type[Numeric], + acc_dtype: Type[Numeric], + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + a_major_mode: OperandMajorMode, + b_major_mode: OperandMajorMode, + ) -> None: + + super().__init__( + ab_dtype, + ab_dtype, + acc_dtype, + instruction_shape, + cta_group, + a_src, + a_major_mode, + b_major_mode, + ) + self._verify() + + def _verify(self) -> None: + # Input data type verification + if self.a_dtype not in [Float8E5M2, Float8E4M3FN]: + raise OpError( + self, + "expects the 'ab_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN", + ) + assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" + # Accumulator data type verification + if self.acc_dtype not in [Float16, Float32]: + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", + ) + # Instruction shape verification + instruction_k = 32 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaFP8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMAType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.a_src._to_ir(), + 0, + ) + return MmaFP8Trait( + _cute_nvgpu_ir.make_sm100_mma( + ty, + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + + +class MmaFP8Trait(MmaTrait): + pass + + +# +# 
MXF8F6F4 MMA +# + + +@dataclass(frozen=True) +class MmaMXF8Op(BlockScaledMmaOp): + """ + MXF8 tcgen05 BlockScaled MMA Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.kind::mxf8f6f4`` qualifier. + """ + + descriptive_name = "tcgen05 MXF8 BlockScaled MMA Operation" + + def __init__( + self, + ab_dtype: Type[Numeric], + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + a_major_mode: OperandMajorMode, + b_major_mode: OperandMajorMode, + ) -> None: + super().__init__( + ab_dtype, + ab_dtype, + Float32, + Float8E8M0FNU, + 32, + instruction_shape, + cta_group, + a_src, + a_major_mode, + b_major_mode, + ) + self._verify() + + def _verify(self) -> None: + # Input data type verification + if self.a_dtype not in [Float8E5M2, Float8E4M3FN]: + raise OpError( + self, + "expects the 'ab_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN", + ) + assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" + # Instruction shape verification + instruction_k = 32 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.sf_dtype.mlir_type, + self.a_src._to_ir(), + self.sf_vec_size, + ) + return MmaMXF8Trait( + _cute_nvgpu_ir.make_sm100_mma_bs( + ty, + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + 
core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + loc=loc, + ip=ip, + ) + ) + + +class MmaMXF8Trait(BlockScaledMmaTraits): + pass + + +# +# MXF4 MMA +# + + +@dataclass(frozen=True) +class MmaMXF4Op(BlockScaledMmaOp): + """ + MXF4 tcgen05 BlockScaled MMA Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.kind::mxf4`` qualifier. + """ + + descriptive_name = "tcgen05 MXF4 BlockScaled MMA Operation" + + def __init__( + self, + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + ) -> None: + super().__init__( + Float4E2M1FN, + Float4E2M1FN, + Float32, + Float8E8M0FNU, + 32, + instruction_shape, + cta_group, + a_src, + OperandMajorMode.K, + OperandMajorMode.K, + ) + self._verify() + + def _verify(self) -> None: + # Instruction shape verification + instruction_k = 64 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.sf_dtype.mlir_type, + self.a_src._to_ir(), + self.sf_vec_size, + ) + return MmaMXF4Trait( + _cute_nvgpu_ir.make_sm100_mma_bs( + ty, + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + 
loc=loc, + ip=ip, + ) + ) + + +class MmaMXF4Trait(BlockScaledMmaTraits): + pass + + +# +# MXF4NVF4 MMA +# + + +@dataclass(frozen=True) +class MmaMXF4NVF4Op(BlockScaledMmaOp): + """ + MXF4NVF4 tcgen05 BlockScaled MMA Operation. + + See the `PTX documentation `__. + This Operation corresponds to the ``.kind::mxf4nvf4`` qualifier. + """ + + descriptive_name = "tcgen05 MXF4NVF4 BlockScaled MMA Operation" + + def __init__( + self, + sf_dtype: Type[Numeric], + instruction_shape: Shape, + cta_group: CtaGroup, + a_src: OperandSource, + ) -> None: + super().__init__( + Float4E2M1FN, + Float4E2M1FN, + Float32, + sf_dtype, + 16, + instruction_shape, + cta_group, + a_src, + OperandMajorMode.K, + OperandMajorMode.K, + ) + self._verify() + + def _verify(self) -> None: + # Scale Factor data type verification + if self.sf_dtype not in [Float8E8M0FNU, Float8E4M3FN]: + raise OpError( + self, + "expects the 'sf_dtype' Op parameter to be one of Float8E8M0FNU", + ) + # Instruction shape verification + instruction_k = 64 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaMXF8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM100UMMABlockScaledType.get( + shape_mnk.type.attribute, + self.cta_group.value, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.sf_dtype.mlir_type, + self.a_src._to_ir(), + self.sf_vec_size, + ) + return MmaMXF4NVF4Trait( + _cute_nvgpu_ir.make_sm100_mma_bs( + ty, + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + Boolean(False).ir_value(loc=loc, ip=ip), + core.make_ptr(self.sf_dtype, 0, 
_cute_ir.AddressSpace.tmem).value, + core.make_ptr(self.sf_dtype, 0, _cute_ir.AddressSpace.tmem).value, + loc=loc, + ip=ip, + ) + ) + + +class MmaMXF4NVF4Trait(BlockScaledMmaTraits): + pass + +#################################################################################################### +# +# SMEM layout atoms +# +#################################################################################################### + + +class SmemLayoutAtomKind(enum.Enum): + """ + Enum class for the kinds of SMEM layout atoms for SM100. + + Given a swizzle kind, an SMEM layout atom is the compact layout of smallest size that can be + used to construct an SMEM layout using blocked product for operand A or B such that the + resulting layout is legal for both TMA and UMMA. + + Note that there are other ways of creating legal layouts for operand A and B. + """ + + MN_INTER = enum.auto() + MN_SW32 = enum.auto() + MN_SW64 = enum.auto() + MN_SW128 = enum.auto() + MN_SW128_32B = enum.auto() + K_INTER = enum.auto() + K_SW32 = enum.auto() + K_SW64 = enum.auto() + K_SW128 = enum.auto() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c2b3f7cf5b0698752d7ea6c450782f17a3fee797 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from .copy import * +from .mma import * + + +# __all__ is required here for documentation generation +__all__ = [ + # mma.py + "MmaF16BF16Op", + # copy.py + "LdMatrix8x8x16bOp", + "LdMatrix16x16x8bOp", + "StMatrix8x8x16bOp", + "StMatrix16x8x8bOp", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py new file mode 100644 index 0000000000000000000000000000000000000000..a6ad4ca8f0e2dd05b6e779eaedec0b69cd47decf --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/copy.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ +from dataclasses import dataclass +from typing import Type + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir import ir + +from ..common import OpError +from ...core import CopyOp, Trait, _pack_shape +from ...typing import Numeric + + +@dataclass(frozen=True) +class BaseOp(CopyOp): + transpose: bool = False + num_matrices: int = 1 + + def __post_init__(self) -> None: + if not isinstance(self.transpose, bool): + raise OpError( + self, + "expects the 'transpose' Op parameter to be a bool instance", + ) + + def __str__(self) -> str: + res = ( + f"{self.__class__.__name__[:-2]} Copy Operation" + + f"\n number of matrices = {self.num_matrices}" + ) + if self.transpose: + res += f"\n transposed" + return res + + +@dataclass(frozen=True) +class LdMatrix8x8x16bOp(BaseOp): + """ + 8x8 ``ldmatrix`` Operation. + + See the `PTX documentation `__. + This operation corresponds to the ``.m8n8`` qualifier. + """ + + def __post_init__(self) -> None: + super().__post_init__() + if self.num_matrices not in [1, 2, 4]: + raise OpError( + self, + "expects the 'num_matrices' Op parameter to be one of [1,2,4]", + ) + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "LdMatrix8x8x16bTrait": + mode = _pack_shape((8, 8), loc=loc, ip=ip) + ty = _cute_nvgpu_ir.CopyAtomLdsmType.get( + copy_internal_type.mlir_type, + mode.type.attribute, + _cute_nvgpu_ir.LdsmSzPattern.u16, + self.num_matrices, + ir.UnitAttr.get() if self.transpose else None, + ) + return LdMatrix8x8x16bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class LdMatrix8x8x16bTrait(Trait): + pass + + +@dataclass(frozen=True) +class LdMatrix16x16x8bOp(BaseOp): + """ + 16x16 8-bit ``ldmatrix`` Operation. + + See the `PTX documentation `__. + This operation corresponds to the ``.m16n16`` and the ``.b16`` qualifiers. 
+ """ + + def __init__(self, num_matrices: int) -> None: + super().__init__(transpose=True, num_matrices=num_matrices) + self._verify() + + def _verify(self): + assert self.transpose, "transpose must be True" + if self.num_matrices not in [1, 2]: + raise OpError( + self, + "expects the 'num_matrices' Op parameter to be one of [1,2]", + ) + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "LdMatrix16x16x8bTrait": + mode = _pack_shape((16, 16), loc=loc, ip=ip) + ty = _cute_nvgpu_ir.CopyAtomLdsmType.get( + copy_internal_type.mlir_type, + mode.type.attribute, + _cute_nvgpu_ir.LdsmSzPattern.u8, + self.num_matrices, + ir.UnitAttr.get(), + ) + return LdMatrix16x16x8bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class LdMatrix16x16x8bTrait(Trait): + pass + + +@dataclass(frozen=True) +class StMatrix8x8x16bOp(BaseOp): + """ + 8x8 ``stmatrix`` Operation. + + See the `PTX documentation `__. + This operation corresponds to the ``m8n8`` qualifier. + """ + + def __post_init__(self) -> None: + super().__post_init__() + if self.num_matrices not in [1, 2, 4]: + raise OpError( + self, + "expects the 'num_matrices' Op parameter to be one of [1,2,4]", + ) + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "StMatrix8x8x16bTrait": + mode = _pack_shape((8, 8), loc=loc, ip=ip) + ty = _cute_nvgpu_ir.CopyAtomStsmType.get( + copy_internal_type.mlir_type, + mode.type.attribute, + self.num_matrices, + ir.UnitAttr.get() if self.transpose else None, + ) + return StMatrix8x8x16bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class StMatrix8x8x16bTrait(Trait): + pass + + +@dataclass(frozen=True) +class StMatrix16x8x8bOp(BaseOp): + """ + 16x8 ``stmatrix`` Operation. + + See the `PTX documentation `__. + This operation corresponds to the ``m16n8`` qualifier. 
+ """ + + def __init__(self, num_matrices: int) -> None: + super().__init__(transpose=True, num_matrices=num_matrices) + self._verify() + + def _verify(self): + if self.num_matrices not in [1, 2, 4]: + assert self.transpose, "transpose must be True" + raise OpError( + self, + "expects the 'num_matrices' Op parameter to be one of [1,2,4]", + ) + + def _make_trait( + self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs + ) -> "StMatrix16x8x8bTrait": + mode = _pack_shape((16, 8), loc=loc, ip=ip) + ty = _cute_nvgpu_ir.CopyAtomStsmType.get( + copy_internal_type.mlir_type, + mode.type.attribute, + self.num_matrices, + ir.UnitAttr.get(), + ) + return StMatrix16x8x8bTrait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + +class StMatrix16x8x8bTrait(Trait): + pass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py new file mode 100644 index 0000000000000000000000000000000000000000..49df213b76f24f23ecfe5a75e36cf17d35aeb98b --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py @@ -0,0 +1,83 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ +from dataclasses import dataclass +from typing import Type + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir + +from ..common import OpError +from ...core import MmaOp, Trait, _pack_shape, _Tensor +from ...typing import Shape, Float16, BFloat16, Float32, Numeric, AddressSpace + + +@dataclass(frozen=True) +class MmaF16BF16Op(MmaOp): + """ + F16/BF16 tcgen05 MMA Operation. + + See the `PTX documentation `__. + This Operation covers the instructions using the ``.f16`` or ``.bf16`` qualifiers for the input operands. + """ + + ab_dtype: Type[Numeric] + acc_dtype: Type[Numeric] + shape_mnk: Shape + + def __post_init__(self) -> None: + if self.ab_dtype not in [Float16, BFloat16]: + raise OpError( + self, + "expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16", + ) + if self.acc_dtype not in [Float16, Float32]: + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", + ) + if (self.ab_dtype == BFloat16) and (self.acc_dtype != Float32): + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be Float32 when 'ab_dtype' is BFloat16", + ) + if self.shape_mnk not in [(16, 8, 8), (16, 8, 16)]: + raise OpError( + self, + "expects the 'shape_mnk' Op parameter to be one of (16,8,8) or (16,8,16)", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM80Type.get( + shape_mnk.type.attribute, + self.ab_dtype.mlir_type, + self.ab_dtype.mlir_type, + self.acc_dtype.mlir_type, + ) + return MmaF16BF16Trait(_cute_ir.atom(ty, loc=loc, ip=ip)) + + def __str__(self) -> str: + return ( + "warp-level F16/BF16 MMA Operation" + + f"\n A/B data type = {self.ab_dtype}" + + f"\n Accumulator data type = {self.acc_dtype}" + + f"\n Instruction shape MNK = {self.shape_mnk}" + ) + + def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None): + pass + 
+ def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None): + pass + +class MmaF16BF16Trait(Trait): + pass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49a40165033024c9c9b17acd298a1f8ba055649c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/__init__.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ +from .mma import * +from .helpers import * + +# __all__ is required here for documentation generation +__all__ = [ + # mma.py + "OperandMajorMode", + "OperandSource", + "Field", + "MmaF16BF16Op", + "MmaF8Op", + "SmemLayoutAtomKind", + # helpers.py + "make_smem_layout_atom", + "fence", + "commit_group", + "wait_group", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..f6284134933bec170ecec5eeb0bf9f829ef0dff0 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/helpers.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Type + +from cutlass.cutlass_dsl import dsl_user_op + +from cutlass._mlir.dialects import nvvm + +from ...typing import Numeric, NumericMeta +from ... import core +from .mma import SmemLayoutAtomKind + + +@dsl_user_op +def make_smem_layout_atom( + kind: SmemLayoutAtomKind, element_type: Type[Numeric], *, loc=None, ip=None +) -> core.ComposedLayout: + """ + Makes a SMEM layout Atom. + + This function creates a composed layout in unit of elements consistent with the requested layout + Atom kind and element data type. 
+ + :param kind: The kind of layout Atom + :type kind: SmemLayoutAtomKind + :param element_type: The element data type to construct the layout for + :type element_type: Type[Numeric] + :return: The SMEM layout atom + :rtype: core.ComposedLayout + """ + if not isinstance(element_type, NumericMeta): + raise TypeError(f"element_type must be a Numeric, but got {element_type}") + + if kind in (SmemLayoutAtomKind.MN_INTER, SmemLayoutAtomKind.K_INTER): + num_contiguous_bits = 128 + sw = core.make_swizzle(0, 4, 3) + elif kind in (SmemLayoutAtomKind.MN_SW32, SmemLayoutAtomKind.K_SW32): + num_contiguous_bits = 256 + sw = core.make_swizzle(1, 4, 3) + elif kind in (SmemLayoutAtomKind.MN_SW64, SmemLayoutAtomKind.K_SW64): + num_contiguous_bits = 512 + sw = core.make_swizzle(2, 4, 3) + elif kind in (SmemLayoutAtomKind.MN_SW128, SmemLayoutAtomKind.K_SW128): + num_contiguous_bits = 1024 + sw = core.make_swizzle(3, 4, 3) + else: + raise ValueError("unrecognized SMEM layout atom kind") + num_contiguous_elems = num_contiguous_bits // element_type.width + + if kind in ( + SmemLayoutAtomKind.MN_INTER, + SmemLayoutAtomKind.MN_SW32, + SmemLayoutAtomKind.MN_SW64, + SmemLayoutAtomKind.MN_SW128, + ): + # M/N-major layout + return core.make_composed_layout( + sw, + 0, + core.make_layout( + (num_contiguous_elems, 8), stride=(1, num_contiguous_elems) + ), + loc=loc, + ip=ip, + ) + else: + # K-major layout + return core.make_composed_layout( + sw, + 0, + core.make_layout( + (8, num_contiguous_elems), stride=(num_contiguous_elems, 1) + ), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def fence(*, loc=None, ip=None) -> None: + """ + See the `PTX documentation `__. + """ + nvvm.wgmma_fence_aligned(loc=None, ip=None) + + +@dsl_user_op +def commit_group(*, loc=None, ip=None) -> None: + """ + See the `PTX documentation `__. + """ + nvvm.wgmma_commit_group_sync_aligned(loc=loc, ip=ip) + + +@dsl_user_op +def wait_group(group, *, loc=None, ip=None) -> None: + """ + See the `PTX documentation `__. 
+ """ + nvvm.wgmma_wait_group_sync_aligned(group, loc=loc, ip=ip) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py new file mode 100644 index 0000000000000000000000000000000000000000..275861f70cc3d6eca932cb263890aaaa4121445f --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/nvgpu/warpgroup/mma.py @@ -0,0 +1,405 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import enum +from dataclasses import dataclass +from typing import Type + +from cutlass.cutlass_dsl import CuTeDSL + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +from cutlass._mlir import ir + +from ..common import OpError +from ...core import MmaOp, Trait, _pack_shape, rank, depth, _Tensor +from ...typing import ( + Shape, + Float16, + BFloat16, + Float32, + Boolean, + Float8E5M2, + Float8E4M3FN, + Numeric, + AddressSpace, +) + + +#################################################################################################### +# +# MMA Ops and Traits +# +#################################################################################################### + + +class OperandMajorMode(enum.Enum): + """ + An enumeration for the majorness of the input operands of the MMA. 
+ """ + + MN = _cute_ir.MajorMode.mn + K = _cute_ir.MajorMode.k + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + @classmethod + def _missing_(cls, value): + if isinstance(value, str): + value = value.upper() + if value == "MN": + return OperandMajorMode.MN + elif value == "K": + return OperandMajorMode.K + + def _to_ir(self) -> _cute_ir.MajorMode: + return self.value + + +class OperandSource(enum.Enum): + """ + An enumeration for the source memory location of the A input operand of the MMA. + """ + + RMEM = _cute_ir.MmaFragKind.rmem + SMEM = _cute_ir.MmaFragKind.smem_desc + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir(self) -> _cute_ir.MmaFragKind: + return self.value + + +class Field(enum.Enum): + """ + An enumeration for the fields of the MMA Atom that can be modified at runtime. 
+ """ + + ACCUMULATE = "accum_c" + + def __str__(self) -> str: + return f"{self.__class__.__name__}.{self.name}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}.{self.name}>" + + def _to_ir_field_name(self) -> str: + return self.value + + +@dataclass(frozen=True) +class MmaOp(MmaOp): + a_dtype: Type[Numeric] + b_dtype: Type[Numeric] + acc_dtype: Type[Numeric] + shape_mnk: Shape + a_src: OperandSource + a_major_mode: OperandMajorMode + b_major_mode: OperandMajorMode + + admissible_archs = ["sm_90a"] + + def __post_init__(self) -> None: + # Verify arch + arch = CuTeDSL._get_dsl().envar.arch + if arch not in self.admissible_archs: + raise OpError( + self, + f"expects arch to be one of {self.admissible_archs}, but got {arch}", + suggestion="Ensure env CUTE_DSL_ARCH matches your GPU architecture", + ) + # Verify that the user provided enum values + if not isinstance(self.a_src, OperandSource): + raise OpError( + self, + "expects the 'a_src' Op parameter to be a warpgroup.OperandSource instance", + ) + if not isinstance(self.a_major_mode, OperandMajorMode): + raise OpError( + self, + "expects the 'a_major_mode' Op parameter to be a warpgroup.OperandMajorMode instance", + ) + if not isinstance(self.b_major_mode, OperandMajorMode): + raise OpError( + self, + "expects the 'b_major_mode' Op parameter to be a warpgroup.OperandMajorMode instance", + ) + # Verify instruction shape + if (rank(self.shape_mnk) not in [2, 3]) or (depth(self.shape_mnk) != 1): + raise OpError( + self, + f"expected a flat rank 2 or 3 tuple for the 'shape_mnk' Op parameter, " + f"but got {self.shape_mnk}", + ) + m, n = self.shape_mnk[0], self.shape_mnk[1] + if m != 64: + raise OpError(self, f"expects the M-mode to be 64, but got {m}") + if (n < 8) or (n > 256) or (n % 8 != 0): + raise OpError( + self, + f"expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0. 
but got {n}", + ) + + def __str__(self) -> str: + return ( + self.__class__.descriptive_name # type: ignore + + f"\n A data type = {self.a_dtype}" + + f"\n B data type = {self.b_dtype}" + + f"\n Accumulator data type = {self.acc_dtype}" + + f"\n A source location = {self.a_src}" + + f"\n A major mode = {self.a_major_mode}" + + f"\n B major mode = {self.b_major_mode}" + + f"\n Instruction shape MNK = {self.shape_mnk}" + ) + + def _verify_fragment_A(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand A, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + + def _verify_fragment_B(self, input: _Tensor, *, loc=None, ip=None): + if input.memspace == AddressSpace.smem and isinstance( + input.layout.type, _cute_ir.ComposedLayoutType + ): + raise OpError( + self, + f"Expected affine layout for {self._make_trait()}'s operand B, " + f"but got composed layout instead: {input.layout}" + f"\nPlease use recast_ptr(ptr, {input.layout.inner}, element_type) operation to move swizzle to the ptr", + ) + return True + + +class MmaTrait(Trait): + admissible_fields = [Field.ACCUMULATE] + + def set(self, field, value, *, loc=None, ip=None) -> None: + if field not in self.admissible_fields: + raise ValueError( + f"invalid field, must be {Field.ACCUMULATE}, but got {field}" + ) + field_name = f"#cute_nvgpu.atom_mma_field_sm90<{field._to_ir_field_name()}>" + attr = ir.Attribute.parse(field_name) + self.value = _cute_nvgpu_ir.atom_set_value( + self.value, attr, Boolean(value).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + + +@dataclass(frozen=True) +class MmaF16BF16Op(MmaOp): + """ + F16/BF16 warpgroup MMA Operation. + + See the `PTX documentation `__. 
+ This Operation covers the instructions using the ``.f16`` or ``.bf16`` qualifiers for the input operands. + """ + + descriptive_name = "warpgroup F16/BF16 MMA Operation" + + def __init__( + self, + ab_dtype: Type[Numeric], + acc_dtype: Type[Numeric], + instruction_shape: Shape, + a_src: OperandSource, + a_major_mode: OperandMajorMode, + b_major_mode: OperandMajorMode, + ) -> None: + super().__init__( + ab_dtype, + ab_dtype, + acc_dtype, + instruction_shape, + a_src, + a_major_mode, + b_major_mode, + ) + self._verify() + + def _verify(self) -> None: + # Input data type verification + if self.a_dtype not in [Float16, BFloat16]: + raise OpError( + self, + "expects the 'ab_dtype' Op parameter to be one of Float16 or BFloat16", + ) + assert self.b_dtype == self.a_dtype, "a_dtype and b_dtype must be the same" + # Accumulator data type verification + if self.acc_dtype not in [Float16, Float32]: + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", + ) + if (self.a_dtype == BFloat16) and (self.acc_dtype != Float32): + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be Float32 when 'ab_dtype' is BFloat16", + ) + # Verify the instruction shape + instruction_k = 16 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF16BF16Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM90Type.get( + shape_mnk.type.attribute, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + self.acc_dtype.mlir_type, + self.a_src._to_ir(), + ) + return MmaF16BF16Trait( + _cute_nvgpu_ir.make_sm90_mma( + ty, + Boolean(False).ir_value(loc=loc, 
ip=ip), + loc=loc, + ip=ip, + ) + ) + + +class MmaF16BF16Trait(MmaTrait): + pass + + +@dataclass(frozen=True) +class MmaF8Op(MmaOp): + """ + F16/BF16 warpgroup MMA Operation. + + See the `PTX documentation `__. + This Operation covers the instructions using the ``.e4m3`` or ``.e5m2`` qualifiers for the input operands. + """ + + descriptive_name = "warpgroup F8 MMA Operation" + + def __init__( + self, + a_dtype: Type[Numeric], + b_dtype: Type[Numeric], + acc_dtype: Type[Numeric], + instruction_shape: Shape, + a_src: OperandSource, + a_major_mode: OperandMajorMode, + b_major_mode: OperandMajorMode, + ) -> None: + super().__init__( + a_dtype, + b_dtype, + acc_dtype, + instruction_shape, + a_src, + a_major_mode, + b_major_mode, + ) + self._verify() + + def _verify(self): + # Input data type verification + if self.a_dtype not in [Float8E5M2, Float8E4M3FN]: + raise OpError( + self, + "expects the 'a_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN", + ) + if self.b_dtype not in [Float8E5M2, Float8E4M3FN]: + raise OpError( + self, + "expects the 'b_dtype' Op parameter to be one of Float8E5M2 or Float8E4M3FN", + ) + # Accumulator data type verification + if self.acc_dtype not in [Float16, Float32]: + raise OpError( + self, + "expects the 'acc_dtype' Op parameter to be one of Float16 or Float32", + ) + # Verify the instruction shape + instruction_k = 32 + if rank(self.shape_mnk) == 2: + object.__setattr__(self, "shape_mnk", (*self.shape_mnk, instruction_k)) + if self.shape_mnk[2] != instruction_k: + raise OpError( + self, + f"expects the instruction extent in the K-mode to be {instruction_k}, " + f"but got {self.shape_mnk[2]}", + ) + + def _make_trait(self, *, loc=None, ip=None, **kwargs) -> "MmaF8Trait": + shape_mnk = _pack_shape(self.shape_mnk, loc=loc, ip=ip) + ty = _cute_nvgpu_ir.MmaAtomSM90Type.get( + shape_mnk.type.attribute, + self.a_major_mode._to_ir(), + self.b_major_mode._to_ir(), + self.a_dtype.mlir_type, + self.b_dtype.mlir_type, + 
self.acc_dtype.mlir_type, + self.a_src._to_ir(), + ) + return MmaF8Trait( + _cute_nvgpu_ir.make_sm90_mma( + ty, Boolean(False).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + ) + + +class MmaF8Trait(MmaTrait): + pass + + +#################################################################################################### +# +# SMEM layout atoms +# +#################################################################################################### + + +class SmemLayoutAtomKind(enum.Enum): + """ + Enum class for the kinds of SMEM layout atoms for SM90. + + Given a swizzle kind, an SMEM layout atom is the compact layout of smallest size that can + be used to construct an SMEM layout using blocked product for operand A or B such that the + resulting layout is legal for both TMA and UMMA. + + Note that there are other ways of creating legal layouts for operand A and B. + """ + + MN_INTER = enum.auto() + MN_SW32 = enum.auto() + MN_SW64 = enum.auto() + MN_SW128 = enum.auto() + K_INTER = enum.auto() + K_SW32 = enum.auto() + K_SW64 = enum.auto() + K_SW128 = enum.auto() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/runtime.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..9128c67a24a7202713c354fb99b2891542f0c887 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/runtime.py @@ -0,0 +1,510 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import ctypes +from functools import lru_cache +import itertools +import operator +from time import time +from typing import Union + +# MLIR modules imports +from cutlass._mlir import ir +import cutlass._mlir.dialects.cute as _cute_ir + +from cutlass.base_dsl.dsl import is_dynamic_expression +from cutlass.cutlass_dsl import JitArgAdapterRegistry + +# Local modules imports +from .typing import ( + AddressSpace, + Tensor, + Type, + Pointer, + Boolean, + Numeric, + Float4E2M1FN, + Int64, + Int32, + Int16, + Int8, + Uint64, + Uint32, + Uint16, + Uint8, + Float64, + Float32, + Float16, + BFloat16, + Float8E5M2, +) +from . import core +from .core import _Tensor as CoreTensor + + +class _Pointer(Pointer): + """Runtime representation of a pointer that can inter-operate with various data structures, + including numpy arrays and device memory. 
+ + :param pointer: The pointer to the data + :type pointer: int or pointer-like object + :param dtype: Data type of the elements pointed to + :type dtype: Type + :param mem_space: Memory space where the pointer resides, defaults to generic + :type mem_space: _cute_ir.AddressSpace, optional + :param assumed_align: Assumed alignment of input pointer in bytes, defaults to None + :type assumed_align: int, optional + + :ivar _pointer: The underlying pointer + :ivar _dtype: Data type of the elements + :ivar _addr_space: Memory space of the pointer + :ivar _assumed_align: Alignment of the pointer in bytes + :ivar _desc: C-type descriptor for the pointer + :ivar _c_pointer: C-compatible pointer representation + """ + + def __init__( + self, + pointer, + dtype, + mem_space: _cute_ir.AddressSpace = _cute_ir.AddressSpace.generic, + assumed_align=None, + ): + self._pointer = pointer + self._dtype = dtype + self._addr_space = mem_space + + if assumed_align is None: + self._assumed_align = dtype.width // 8 + else: + self._assumed_align = assumed_align + + self._c_pointer = None + assert ( + int(self._pointer) % self._assumed_align == 0 + ), f"pointer must be {self._assumed_align} bytes aligned" + + def size_in_bytes(self) -> int: + self._desc = ctypes.c_void_p(int(self._pointer)) + return ctypes.sizeof(self._desc) + + def __get_mlir_types__(self): + return [self.mlir_type] + + def __c_pointers__(self): + if self._c_pointer is None: + self._desc = ctypes.c_void_p(int(self._pointer)) + self._c_pointer = ctypes.addressof(self._desc) + return [self._c_pointer] + + def __new_from_mlir_values__(self, values): + assert len(values) == 1 + return values[0] + + def __extract_mlir_values__(self): + return [self._c_pointer] + + # Move mlir Type out of __init__ to decouple with mlir Context + @property + def mlir_type(self) -> ir.Type: + return _cute_ir.PtrType.get( + self._dtype.mlir_type, self._addr_space, self._assumed_align + ) + + @property + def dtype(self) -> Type[Numeric]: + return 
self._dtype + + @property + def memspace(self): + return self._addr_space + + def align(self, min_align: int, *, loc=None, ip=None) -> Pointer: + raise NotImplementedError("align is not supported in runtime") + + def verify(self, expected_py_type): + if expected_py_type is Pointer: + return True + elif isinstance(expected_py_type, ir.Value) and expected_py_type.ty is Pointer: + return True + + return False + + def __str__(self) -> str: + return f"Ptr<0x{int(self._pointer):016x}@{self._addr_space}>" + + def __repr__(self): + return self.__str__() + + +class _Tensor(Tensor): + def __init__( + self, + tensor, + assumed_align=None, + ): + # If tensor is already a DLPack object, use it directly + if hasattr(tensor, "__dlpack_device__") and not hasattr(tensor, "__dlpack__"): + self._dlpack_data = tensor + else: + self._dlpack_data = tensor.__dlpack__() + self._dltensor_wrapper = None + self._assumed_align = assumed_align + self._is_dynamic = False + self._memref_desc = None + self._dtype = None + + @property + def __class__(self) -> Type[Tensor]: + # Cheat to let `type(_Tensor())` to return cute.Tensor + return Tensor + + @staticmethod + def lazily_load_dltensor(func): + """Decorator to lazily load the DLTensorWrapper. + + This decorator loads the DLTensorWrapper when needed, + avoiding overhead in the critical path of calling JIT functions. + """ + + def wrapper(self, *args, **kwargs): + if self._dltensor_wrapper is None: + self._dltensor_wrapper = _cute_ir.DLTensorWrapper(self._dlpack_data) + return func(self, *args, **kwargs) + + return wrapper + + @lazily_load_dltensor + def mark_layout_dynamic(self, leading_dim: int | None = None): + """Marks the tensor layout as dynamic based on the leading dimension. + + :param leading_dim: The leading dimension of the layout, defaults to None + :type leading_dim: int, optional + + When ``leading_dim`` is None, automatically deduces the leading dimension from the tensor layout. 
+ The layout can be deduced only when exactly one dimension has a stride of 1. Raises an error + if the layout cannot be automatically deduced. + + When ``leading_dim`` is explicitly specified, marks the layout as dynamic while setting the + stride at ``leading_dim`` to 1. Also validates that the specified ``leading_dim`` is consistent + with the existing layout by checking that the corresponding stride of that dimension is 1. + + Limitation: only support flat layout for now. Will work on supporting nested layout in the future. + + :return: The tensor with dynamic layout + :rtype: _Tensor + """ + self._dltensor_wrapper.mark_layout_dynamic(leading_dim) + return self + + @lazily_load_dltensor + def mark_compact_shape_dynamic( + self, + mode: int, + stride_order: tuple[int, ...] | None = None, + divisibility: int = 1, + ): + """Marks the tensor shape as dynamic and propagates dynamic and divisibility information to the corresponding strides. + + :param mode: The mode of the compact shape, defaults to 0 + :type mode: int + :param stride_order: Consistent with `torch.Tensor.dim_order`. Defaults to None. + Indicates the order of the modes (dimensions) if the current layout were converted to row-major order. + It starts from the outermost to the innermost dimension. + :type stride_order: tuple[int, ...], optional + :param divisibility: The divisibility constraint for the compact shape, defaults to 1 + :type divisibility: int, optional + :return: The tensor with dynamic compact shape + :rtype: _Tensor + + If ``stride_order`` is not provided, the stride ordering will be automatically deduced from the layout. + Automatic deduction is only possible when exactly one dimension has a stride of 1 (compact layout). + An error is raised if automatic deduction fails. + + If ``stride_order`` is explicitly specified, it does the consistency check with the layout. 
+ + For example: + - Layout: (4,2):(1,4) has stride_order: (1,0) indicates the innermost dimension is 0(`4:1`), the outermost dimension is 1(`2:4`) + - Layout: (5,3,2,4):(3,1,15,30) has stride_order: (3,2,0,1) indicates the innermost dimension is 1(`3:1`), the outermost dimension is 3(`4:30`). + + Using `torch.Tensor.dim_order()` to get the stride order of the torch tensor. + .. code-block:: python + a = torch.empty(3, 4) + t = cute.runtime.from_dlpack(a) + t = t.mark_compact_shape_dynamic(mode=0, stride_order=a.dim_order()) + """ + self._dltensor_wrapper.mark_compact_shape_dynamic( + mode, stride_order, divisibility + ) + return self + + @property + @lazily_load_dltensor + def element_type(self) -> Type[Numeric]: + if self._dtype is None: + self._dtype = self._dltensor_wrapper.dtype + return self._dtype + + @element_type.setter + def element_type(self, new_type): + """Set the element type of the tensor. + + :warning: This API is added for narrow precision before we have a clean `recast_tensor` story. + + :note: It is only used for the case that frameworks don't natively support narrow precision but we get tensor + from frameworks with storage type like uint8. + + **Example**: + + .. code-block:: python + + # Create a tensor from a numpy array + import numpy as np + from cutlass.cute import from_dlpack + + # Create a tensor with Float32 elements + a = np.zeros(shape, dtype=np.uint8) + tensor = from_dlpack(a) + + # Change the element type to Float4E2M1FN even storage type is uint8 + tensor.element_type = cutlass.Float4E2M1FN + + src = from_dlpack(... data tensor ...) 
+ # convert and initialize narrow precision tensor + cute.testing.convert(src, tensor) + """ + self._dtype = new_type + + @property + @lazily_load_dltensor + def memspace(self): + return self._dltensor_wrapper.address_space + + @property + @lazily_load_dltensor + def size_in_bytes(self) -> int: + return self._dltensor_wrapper.size_in_bytes() + + @property + @lazily_load_dltensor + def mlir_type(self) -> ir.Type: + return self._dltensor_wrapper.get_type( + self.element_type.mlir_type, self._assumed_align + ) + + @lazily_load_dltensor + def __str__(self) -> str: + return f"Tensor<0x{self._dltensor_wrapper.str}>" + + def __repr__(self): + return self.__str__() + + def __setitem__(self, crd, value): + raise TypeError(f"runtime._Tensor is not indexable") + + def __getitem__(self, crd): + raise TypeError(f"runtime._Tensor is not indexable") + + @property + @lazily_load_dltensor + def iterator(self): + return _Pointer( + self._dltensor_wrapper.data_ptr, + self.element_type, + self.memspace, + self._assumed_align, + ) + + @property + def layout(self): + raise NotImplementedError( + f"layout property is not supported in runtime, support in future" + ) + + @property + @lazily_load_dltensor + def shape(self): + return self._dltensor_wrapper.shape + + @property + @lazily_load_dltensor + def stride(self): + strides = self._dltensor_wrapper.stride + if strides is None: + strides = itertools.accumulate( + reversed(self.shape), func=operator.mul, initial=1 + ) + strides = tuple(reversed(list(strides)[:-1])) + + return strides + + @property + @lru_cache(maxsize=128, typed=True) + def leading_dim(self): + """Get the leading dimension of this Tensor. 
+ + :return: The leading dimension index or indices + :rtype: int or tuple or None + + The return value depends on the tensor's stride pattern: + + * If a single leading dimension is found, returns an integer index + * If nested leading dimensions are found, returns a tuple of indices + * If no leading dimension is found, returns None + """ + return core.leading_dim(self.shape, self.stride) + + def fill(self, value: Numeric): + raise TypeError(f"fill function is not supported in runtime") + + @property + @lazily_load_dltensor + def data_ptr(self): + return self._dltensor_wrapper.data_ptr + + @lazily_load_dltensor + def __c_pointers__(self): + self._memref_desc = self._dltensor_wrapper.build_memref_desc( + self._assumed_align + ) + return [_cute_ir.pycapsule_get_pointer(self._memref_desc)] + + def __get_mlir_types__(self): + return [self.mlir_type] + + def __new_from_mlir_values__(self, values): + assert len(values) == 1 + assert isinstance(values[0], CoreTensor) + return CoreTensor(values[0].value, self._dtype) + + +def from_dlpack( + tensor_dlpack, + assumed_align=None, +) -> Tensor: + """Convert from tensor object supporting __dlpack__() to a CuTe Tensor. + + :param tensor_dlpack: Tensor object that supports the DLPack protocol + :type tensor_dlpack: object + :param assumed_align: Assumed alignment of the tensor (bytes), defaults to None, + if None, will use the element size bytes as the assumed alignment. + :type assumed_align: int, optional + :return: A CuTe Tensor object + :rtype: Tensor + + Examples: + .. 
code-block:: python + + import torch + from cutlass.cute.runtime import from_dlpack + x = torch.randn(100, 100) + y = from_dlpack(x) + y.shape + # (100, 100) + type(y) + # + """ + return _Tensor( + tensor_dlpack, + assumed_align=assumed_align, + ) + + +def make_ptr( + dtype: Type[Numeric], + value: Union[int, ctypes._Pointer], + mem_space: AddressSpace = AddressSpace.generic, + assumed_align=None, +) -> Pointer: + """Create a pointer from a memory address + + :param dtype: Data type of the pointer elements + :type dtype: Type[Numeric] + :param value: Memory address as integer or ctypes pointer + :type value: Union[int, ctypes._Pointer] + :param mem_space: Memory address space, defaults to AddressSpace.generic + :type mem_space: AddressSpace, optional + :param align_bytes: Alignment in bytes, defaults to None + :type align_bytes: int, optional + :return: A pointer object + :rtype: Pointer + + .. code-block:: python + + import numpy as np + import ctypes + + from cutlass import Float32 + from cutlass.cute.runtime import make_ptr + + # Create a numpy array + a = np.random.randn(16, 32).astype(np.float32) + + # Get pointer address as integer + ptr_address = a.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) + + # Create pointer from address + y = make_ptr(cutlass.Float32, ptr_address) + + # Check properties + print(y.element_type) + print(type(y)) # + """ + # check if value is int or ctypes.POINTER + if isinstance(value, int): + address_value = value + elif isinstance(value, ctypes._Pointer): + # get address value + address_value = ctypes.cast(value, ctypes.c_void_p).value + assert address_value is not None, "Pointer address is None" + else: + raise TypeError( + f"Expect int or ctypes.POINTER for value but got {type(value)=}" + ) + + return _Pointer(address_value, dtype, mem_space, assumed_align=assumed_align) + + +class TensorAdapter: + """ + Convert a DLPack protocol supported tensor/array to a cute tensor. 
+ """ + + def __init__(self, arg): + self._arg = from_dlpack(arg).mark_layout_dynamic() + + def __new_from_mlir_values__(self, values): + return self._arg.__new_from_mlir_values__(values) + + def __c_pointers__(self): + return self._arg.__c_pointers__() + + def __get_mlir_types__(self): + return self._arg.__get_mlir_types__() + + +# ------------------------------------------------------------------------- +# Try to register_jit_arg_adapter for TensorAdapter +# ------------------------------------------------------------------------- + +try: # Register for numpy.ndarray + import numpy + + JitArgAdapterRegistry.register_jit_arg_adapter(numpy.ndarray)(TensorAdapter) +except ImportError: + pass # silent attempt, suppress error + +try: # Register for torch.Tensor + import torch + + JitArgAdapterRegistry.register_jit_arg_adapter(torch.Tensor)(TensorAdapter) +except ImportError: + pass # silent attempt, suppress error diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/testing.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/testing.py new file mode 100644 index 0000000000000000000000000000000000000000..88e0da048fc951da5091bcc38a6e6c92164f6d04 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/testing.py @@ -0,0 +1,610 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ +import functools +import inspect +import logging +import os +from enum import Enum +from inspect import isclass +from itertools import product +from time import time +from typing import Any, Callable, Dict, List, Optional, Type, Union + +import cuda.bindings.driver as cuda_driver +import cuda.bindings.runtime as cuda_runtime +import numpy as np + +import cutlass._mlir.ir as ir +import cutlass.base_dsl.jit_executor +import cutlass.cute as cute +from cutlass._mlir.dialects import builtin, cf, nvvm, vector +from cutlass.cute import core, nvgpu +from cutlass.cutlass_dsl import Constexpr, CuTeDSL, T, t, dsl_user_op + + +@dsl_user_op +def assert_(cond, msg=None, *, loc=None, ip=None): + cf.assert_(t.Boolean(cond).ir_value(), msg if msg else "", loc=loc, ip=ip) + + +def _maybe_recast_tensor_from_f4(src: core.Tensor, tv_layout: core.Layout): + if src.element_type.width == 4: + tv_layout = core.recast_layout(8, 4, tv_layout) + src = core.recast_tensor(src, dtype=t.Int8) + return src, tv_layout + + +def _maybe_recast_to_f4(input: core.TensorSSA, dtype: Type[core.Numeric]): + """Conditionally recasts the tensor to 4-bit type if the destination type is 4-bit. + + :param input: The input tensor to recast. + :param dtype: The target numeric type to potentially recast to. + :raises TypeError: If dtype is not a subclass of Numeric. + :return: A new tensor recast to 4-bit if dtype is 4-bit, otherwise returns self unchanged. 
+ """ + if not isclass(dtype) or not issubclass(dtype, core.Numeric): + raise TypeError(f"dst_ty must be a type of Numeric, but got {dtype}") + + if dtype.width == 4: + recast_shape = core.recast_layout(4, 8, core.make_layout(input.shape)).shape + i4_vec = vector.bitcast( + T.vector(input.type.shape[0] * 2, T.i(4)), input.maybe_downcast() + ) + res_vect = builtin.unrealized_conversion_cast( + [T.vector(i4_vec.type.shape[0], dtype.mlir_type)], [i4_vec] + ) + return core.TensorSSA(res_vect, recast_shape, dtype) + return input + + +def _maybe_recast_from_f4(input: core.TensorSSA, src_dtype: Type[core.Numeric]): + """Conditionally recasts the tensor from 4-bit type if the source type is 4-bit. + + :param input: The input tensor to recast. + :param src_dtype: The source numeric type to potentially recast from. + :raises TypeError: If src_dtype is not a subclass of Numeric. + :return: A new tensor recast from 4-bit if src_dtype is 4-bit, otherwise returns self unchanged. + """ + if not isclass(src_dtype) or not issubclass(src_dtype, core.Numeric): + raise TypeError(f"src_ty must be a type of Numeric, but got {src_dtype}") + + if src_dtype.width == 4: + recast_shape = core.recast_layout(8, 4, core.make_layout(input.shape)).shape + i4_vec = builtin.unrealized_conversion_cast( + [T.vector(input.type.shape[0], T.i(4))], [input.maybe_downcast()] + ) + res_vect = vector.bitcast(T.vector(i4_vec.type.shape[0] // 2, T.i8()), i4_vec) + return core.TensorSSA(res_vect, recast_shape, core.Int8) + return input + + +@CuTeDSL.kernel +def _convert_kernel( + gSrc: core.Tensor, + gDst: core.Tensor, + cSrc: core.Tensor, + src_tv_layout: core.Layout, + dst_tv_layout: core.Layout, + src_shape: core.Shape, + src_ty, + dst_ty, +): + tidx = nvvm.read_ptx_sreg_tid_x(T.i32()) + bidx = nvvm.read_ptx_sreg_ctaid_x(T.i32()) + + cta_coord = (None, bidx) + # logical idx -> address + ctaSrc = gSrc[cta_coord] # (...,TileV,...) + ctaDst = gDst[cta_coord] # (...,TileV,...) 
+ ctaCSrc = cSrc[cta_coord] # (...,TileV,...) + # print(f"ctaSrc = {ctaSrc.type}") + + # compose with CTA TV layout + # tid, vid -> address + tidfrgSrc = core.composition(ctaSrc, src_tv_layout) # (T,V) + tidfrgDst = core.composition(ctaDst, dst_tv_layout) # (T,V) + tidfrgCSrc = core.composition(ctaCSrc, src_tv_layout) # (T,V) + # print(f"tidfrgSrc = {tidfrgSrc.type}") + + # slice for threads + thr_coord = (tidx, None) + thrSrc = tidfrgSrc[thr_coord] # (V) + thrDst = tidfrgDst[thr_coord] # (V) + thrCSrc = tidfrgCSrc[thr_coord] # (V) + # print(f"thrSrc = {thrSrc.type}") + + # predicate + if core.elem_less(thrCSrc[0], src_shape): + # allocate fragments for gmem->rmem + frgSrc = core.make_fragment( + core.get(src_tv_layout, mode=[1]), gSrc.element_type + ) # (V) + frgDst = core.make_fragment( + core.get(dst_tv_layout, mode=[1]), gDst.element_type + ) # (V) + # print(f"frgSrc = {frgSrc.type}") + + # Move data to reg address space + copy_atom_load = core.make_copy_atom(nvgpu.CopyUniversalOp(), gSrc.element_type) + core.copy(copy_atom_load, thrSrc, frgSrc) + + vec_src = frgSrc.load() + vec_src = _maybe_recast_to_f4(vec_src, src_ty) + vec_dst = vec_src.to(dst_ty) + vec_dst = _maybe_recast_from_f4(vec_dst, dst_ty) + frgDst.store(vec_dst) + + # Copy the results back to c + copy_atom_stg = core.make_copy_atom(nvgpu.CopyUniversalOp(), gDst.element_type) + core.copy(copy_atom_stg, frgDst, thrDst) + + +@CuTeDSL.jit(preprocess=False) +def _convert( + src: core.Tensor, + dst: core.Tensor, + leading_mode: Constexpr, + elem_per_copy: Constexpr, +): + + # Step 1. figure proper tv_layout + src_ty = src.element_type + dst_ty = dst.element_type + + tv_layout = core.make_layout((128, elem_per_copy), stride=(elem_per_copy, 1)) + + # Step 2. 
maybe recast from f4 tensor + src, src_tv_layout = _maybe_recast_tensor_from_f4(src, tv_layout) + dst, dst_tv_layout = _maybe_recast_tensor_from_f4(dst, tv_layout) + src_shape = src.shape + # predicate tensor + idA = core.make_identity_tensor(src.shape) + + # Step 3. select a proper tiling pattern as (...,TileV, ...) + src_cta_tiler = [ + 1, + ] * core.rank(src.layout) + src_cta_tiler[leading_mode] = core.size(src_tv_layout) # (...,TileV,...) + dst_cta_tiler = [ + 1, + ] * core.rank(dst.layout) + dst_cta_tiler[leading_mode] = core.size(dst_tv_layout) # (...,TileV,...) + + # Step 4. partition input and output tensor by cta tiler. + gS = core.zipped_divide( + src, tuple(src_cta_tiler) + ) # ((...,TileV,...),(...,RestV,...)) + cS = core.zipped_divide( + idA, tuple(src_cta_tiler) + ) # ((...,TileV,...),(...,RestV,...)) + gD = core.zipped_divide( + dst, tuple(dst_cta_tiler) + ) # ((...,TileV,...),(...,RestV,...)) + # print(f"{gS.type=}") + + _convert_kernel( + gS, + gD, + cS, + src_tv_layout, + dst_tv_layout, + src_shape, + src_ty, + dst_ty, + ).launch( + grid=[core.size(gS, mode=[1]), 1, 1], + block=[core.size(src_tv_layout, mode=[0]), 1, 1], + ) + + +# Converts from src tensor to dst tensor, their logical shape are required to be the same. +# And when src or dst dtype is narrow precision(Float4E2M1FN/Float8E8M0FNU/Float8E4M3FN), the shape of +# their leading dimension should be 4(fp8)/8(fp4) element align. (nvgpu.cvt_fptrunc/cvt_fpext +# needs 32-bits aligned input/output) +def convert(src: core.Tensor, dst: core.Tensor): + assert len(src.shape) == len( + dst.shape + ), "Shape of src and dst tensors should be the same rank." 
+ # find leading mode + leading_mode = [ + idx + for idx, (shape, stride) in enumerate(zip(src.shape, src.stride)) + if shape > 1 and stride == 1 + ] + if len(leading_mode) != 1: + raise ValueError(f"Leading mode should be unique, but got {leading_mode}") + leading_mode = leading_mode[0] + + elem_per_copy = 2 + + if src.element_type.width == 4 or dst.element_type.width == 4: + elem_per_copy = 8 + elif src.element_type.width == 8 or dst.element_type.width == 8: + elem_per_copy = 4 + assert ( + src.shape[leading_mode] % elem_per_copy == 0 + and dst.shape[leading_mode] % elem_per_copy == 0 + ) + _convert(src, dst, leading_mode, elem_per_copy) + + +######################################### +# Testing utilities +######################################### + + +def sample_pytest(rand_cfg=None): + """ + Decorator to randomly sample pytest parametrized tests. + rand_cfg: Tuple[int, float] - (random_seed, sample_ratio) + Sampling is disabled when: + - A specific test is selected (via -k or direct test path) + - Not running under pytest + """ + import functools + import os + import random + import sys + + import pytest + + seed, sample_ratio = rand_cfg + random.seed(seed) + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if rand_cfg is not None and "PYTEST_CURRENT_TEST" in os.environ: + # Check if test was explicitly selected like ::test_name[param1-param2-...] + if "-k" in sys.argv or any(".py::" in arg for arg in sys.argv): + # Test was explicitly selected, don't skip + return func(*args, **kwargs) + + if random.uniform(0.0, 1.0) > sample_ratio: + pytest.skip(f"Randomly skipped (sampling ratio: {sample_ratio})") + return func(*args, **kwargs) + + return wrapper + + return decorator + + +######################################### +# Benchmarking utilities +######################################### + + +class JitArguments: + """ + A type to hold both args and kwargs for passing to a kernel while benchmarking. 
+ """ + + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + +def _cuda_success( + err: Union[tuple, cuda_runtime.cudaError_t, cuda_driver.CUresult], message: str +): + """ + Helper function to check CUDA API errors. + """ + if isinstance(err, tuple): + _cuda_success(err[0], message) + elif isinstance(err, cuda_runtime.cudaError_t): + error_message = cuda_runtime.cudaGetErrorString(err)[1].decode("utf-8") + if err != cuda_runtime.cudaError_t.cudaSuccess: + raise RuntimeError(f"{message} : {error_message}") + elif isinstance(err, cuda_driver.CUresult): + if err != cuda_driver.CUresult.CUDA_SUCCESS: + error_message = cuda_driver.cuGetErrorString(err)[1].decode("utf-8") + raise RuntimeError(f"{message} : {error_message}") + else: + raise TypeError( + f"{err} is an unexpected type : it should be a cudaError_t or CUresult" + ) + + +def _does_kernel_use_stream( + kernel: Callable, stream: cuda_driver.CUstream, *args, **kwargs +): + """ + This function checks if the kernel uses the provided non-default stream. + It does this by capturing the stream and then checking if any kernels were launched. 
+ :param kernel: The kernel to check + :type kernel: Callable + :param stream: The stream to check + :type stream: cuda_driver.CUstream + :return: True if the kernel uses the stream, False otherwise + :rtype: bool + """ + + assert int(stream) != int( + cuda_driver.CUstream_flags.CU_STREAM_DEFAULT + ), "Stream must be a non-default stream" + + err = cuda_runtime.cudaStreamBeginCapture( + stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal + ) + _cuda_success(err, "Error on stream capture") + + kernel(*args, **kwargs) + + err, graph = cuda_runtime.cudaStreamEndCapture(stream) + _cuda_success(err, "Error on stream capture") + + # Get number of nodes in warmup graph to check it matches what is expected + err, _, num_nodes = cuda_runtime.cudaGraphGetNodes(graph) + _cuda_success(err, "Error on querying graph") + return num_nodes > 0 + + +def benchmark( + callable: Callable, + *, + warmup_iterations: int = 10, + iterations: int = 100, + stream: Optional[cuda_driver.CUstream] = None, + kernel_arguments: Optional[JitArguments] = None, + workspace_generator: Optional[Callable[[], JitArguments]] = None, + workspace_count: int = 1, + use_cuda_graphs: bool = False, +) -> float: + """Benchmarks a callable function with the specified parameters. + + For example, + .. code-block:: python + + from cutlass.cute.testing import benchmark + + @cute.jit + def user_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda_driver.CUstream): + # contents of the function + pass + + time_us = benchmark(user_function, kernel_arguments=JitArguments(a, b, c, stream) + warmup_iterations=10, iterations=100 + stream=stream) + + To prevent skewing results by repeately accessing the L2 cache, use the workspace_count and workspace_generator + parameters to cycle through a number of different workspaces. + + .. 
code-block:: python + + from cutlass.cute.testing import benchmark + + @cute.jit + def user_function(a: cute.Tensor, b: cute.Tensor, c: cute.Tensor): + # contents of the function + pass + + def workspace_generator(): + # create a, b, and c + return JitArguments(a, b, c) + + time_us = benchmark(user_function, + workspace_generator=workspace_generator, + workspace_count=10, + warmup_iterations=10000, + iterations=1000) + + To benchmark you may always configure the function being profiled (callable), the warmup iterations, and + the number of profiling iterations. + + Whenever the kernel being benchmarked runs in a non-default stream, the stream must be provided through the stream parameter. + + To use CUDA graphs, the callable must be a compiled @cute.jit annotated function. + When using CUDA graphs, the kernel must be launched in a non-default stream. + + :param callable: The function to benchmark + :type callable: Callable + :param warmup_iterations: Number of warmup iterations, defaults to 10 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations, defaults to 100 + :type iterations: int, optional + :param stream: Stream kernel is launched in, defaults to CUDA stream default + :type stream: CUstream, None + :param kernel_arguments: Kernel arguments to launch callable with, defaults to None + :type kernel_arguments: JitArguments, None + :param workspace_generator: Function that returns kernel arguments, defaults to None + :type workspace_generator: Callable + :param workspace_count: Number of workspaces (arguments) to loop through, looping through enough workspaces will keep the L2 cache cold + :type workspace_count: int, optional + :param use_cuda_graphs: Whether to use cuda graphs, defaults to False + :type use_cuda_graphs: bool, optional + + :return: The benchmark time in microseconds + :rtype: float + """ + + if stream is None: + stream = cuda_driver.CUstream(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT) + + if 
workspace_count < 1: + raise ValueError("workspace_count must be at least 1") + + time_us = float("nan") + if workspace_generator == None: + # If no workspace generator is provided, we need a single workspace + if workspace_count != 1: + raise ValueError("Need a single workspace if not providing a generator") + + # If no workspace generator is provided, we need a kernel_argument + if kernel_arguments == None: + raise ValueError( + "Please pass a kernel argument if not providing a generator" + ) + workspace_generator = lambda: kernel_arguments + + workspaces = [workspace_generator() for _ in range(workspace_count)] + + for workspace in workspaces: + if type(workspace) != JitArguments: + raise TypeError( + "workspace_generator and/or kernel_arguments should use JitArguments type" + ) + + def _loop_and_call_kernel(iterations: int, workspace_index: int = 0): + for _ in range(iterations): + current_workspace = workspaces[workspace_index] + callable(*current_workspace.args, **current_workspace.kwargs) + workspace_index = (workspace_index + 1) % workspace_count + return workspace_index + + # Create CUDA events for timing + err, start_event = cuda_driver.cuEventCreate( + cuda_driver.CUevent_flags.CU_EVENT_DEFAULT + ) + _cuda_success(err, "Error on creating event") + err, end_event = cuda_driver.cuEventCreate( + cuda_driver.CUevent_flags.CU_EVENT_DEFAULT + ) + _cuda_success(err, "Error on creating event") + + elapsed_time = float("nan") + + if use_cuda_graphs: + # Check if the callable is a JitExecutor + if not isinstance(callable, cutlass.base_dsl.jit_executor.JitExecutor): + raise TypeError("Function must be precompiled to be used with CUDA Graphs") + + # Check if the stream is a non-default stream + if int(stream) == int(cuda_driver.CUstream_flags.CU_STREAM_DEFAULT): + raise ValueError( + "Measuring with CUDA Graphs requires executing in a non-default stream" + ) + + workspace_index = 0 + + # Capture warmup graph + err = cuda_runtime.cudaStreamBeginCapture( + stream, 
cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal + ) + _cuda_success(err, "Error on stream capture") + + workspace_index = _loop_and_call_kernel(warmup_iterations) + err, gwarm = cuda_runtime.cudaStreamEndCapture(stream) + _cuda_success(err, "Error on stream capture") + + # Get number of nodes in warmup graph to check it matches what is expected + err, _, num_nodes = cuda_runtime.cudaGraphGetNodes(gwarm) + _cuda_success(err, "Error on querying graph") + # Assertion is >= since we may launch multiple kernels in one host function + if num_nodes < warmup_iterations: + raise ValueError( + f"CUDA stream passed to benchmark does not match the stream the kernel was launched in" + ) + + # Capture profiling graph + err = cuda_runtime.cudaStreamBeginCapture( + stream, cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal + ) + _cuda_success(err, "Error on stream capture") + _loop_and_call_kernel(iterations, workspace_index) + err, gprofile = cuda_runtime.cudaStreamEndCapture(stream) + _cuda_success(err, "Error on stream capture") + + # Instantiate graphs + err, gwarm = cuda_runtime.cudaGraphInstantiate(gwarm, 0) + _cuda_success(err, "Error on graph instantiation") + err, gprofile = cuda_runtime.cudaGraphInstantiate(gprofile, 0) + _cuda_success(err, "Error on graph instantiation") + + # Launch warmup graph + err = cuda_runtime.cudaGraphLaunch(gwarm, stream) + _cuda_success(err, "Error on graph launch") + + # Record start time + err = cuda_driver.cuEventRecord(start_event, stream) + _cuda_success(err, "Error on recording event") + + # Launch profiling graph + err = cuda_runtime.cudaGraphLaunch(gprofile, stream) + _cuda_success(err, "Error on graph launch") + + # Record end time + err = cuda_driver.cuEventRecord(end_event, stream) + _cuda_success(err, "Error on recording event") + err = cuda_driver.cuEventSynchronize(end_event) + _cuda_success(err, "Error on synchronizing event") + + # Get elapsed time + err, elapsed_time = 
cuda_driver.cuEventElapsedTime(start_event, end_event) + _cuda_success(err, "Error on querying event") + + # Destroy graphs + err = cuda_runtime.cudaGraphExecDestroy(gwarm) + _cuda_success(err, "Error on destroying graph") + err = cuda_runtime.cudaGraphExecDestroy(gprofile) + _cuda_success(err, "Error on destroying graph") + + else: + + if int(stream) != int( + cuda_driver.CUstream_flags.CU_STREAM_DEFAULT + ) and not _does_kernel_use_stream( + callable, stream, *workspaces[0].args, **workspaces[0].kwargs + ): + raise ValueError( + "CUDA stream passed to benchmark does not match the stream the kernel was launched in" + ) + + # Not using graphs + # Warmup + workspace_index = _loop_and_call_kernel(warmup_iterations) + # Record start event + err = cuda_driver.cuEventRecord(start_event, stream) + _cuda_success(err, "Error on recording event") + _loop_and_call_kernel(iterations, workspace_index) + # Record end event + err = cuda_driver.cuEventRecord(end_event, stream) + _cuda_success(err, "Error on recording event") + # Synchronize end event + err = cuda_driver.cuEventSynchronize(end_event) + _cuda_success(err, "Error on synchronizing event") + err, elapsed_time = cuda_driver.cuEventElapsedTime(start_event, end_event) + _cuda_success(err, "Error on querying event") + + # Destroy events + err = cuda_driver.cuEventDestroy(start_event) + _cuda_success(err, "Error on destroying event") + err = cuda_driver.cuEventDestroy(end_event) + _cuda_success(err, "Error on destroying event") + + return elapsed_time / iterations * 1e3 + + +def get_workspace_count( + one_workspace_bytes: int, warmup_iterations: int, iterations: int +) -> int: + """Calculate the number of workspaces needed to fill L2 cache. 
+ + :param one_workspace_bytes: Size of one workspace in bytes + :type one_workspace_bytes: int + :param warmup_iterations: Number of warmup iterations + :type warmup_iterations: int + :param iterations: Number of iterations + :type iterations: int + :return: Number of workspaces needed + :rtype: int + """ + num_l2_cache_bytes = cutlass.utils.HardwareInfo().get_l2_cache_size_in_bytes() + return max( + 1, + min( + warmup_iterations + iterations, # Don't create more workspaces than needed + (num_l2_cache_bytes + one_workspace_bytes - 1) + // one_workspace_bytes, # Ceiling division + ), + ) + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/typing.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..215e71d98fc39c192c784c99bb8ef14f6e2f55d9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/cute/typing.py @@ -0,0 +1,207 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 

from abc import ABC, abstractmethod
from typing import ForwardRef, Tuple, Union, Any, Type, List

from cutlass.base_dsl.typing import *

from cutlass._mlir import ir
import cutlass._mlir.extras.types as T
from cutlass._mlir.dialects.cute import AddressSpace


# An integer is either a plain Python int or an ``Integer``.
# NOTE(review): ``Integer`` is not imported by name here; presumably it comes
# from the star import of cutlass.base_dsl.typing above - confirm.
Int = Union[int, Integer]


# Placeholder forward reference: no ``ScaledBasis`` class is defined in this module
ScaledBasis = ForwardRef("ScaledBasis")


# Recursive aliases: each is either a leaf value or an arbitrarily nested tuple
# of itself. ``Stride`` additionally allows ScaledBasis leaves and ``Coord``
# additionally allows None leaves.
IntTuple = Union[Int, Tuple["IntTuple", ...]]
Shape = Union[Int, Tuple["Shape", ...]]
Stride = Union[Int, ScaledBasis, Tuple["Stride", ...]]
Coord = Union[Int, None, Tuple["Coord", ...]]


class Layout(ir.Value):
    """
    Typing declaration for a layout value, derived from ``ir.Value``.

    Apart from ``__init__``, every member below has a ``...`` body, i.e. this class only
    declares the interface here.
    """

    def __init__(self, op_result):
        # Forward the wrapped result unchanged to the ir.Value constructor
        super().__init__(op_result)

    # Declaration only (body is ``...``)
    def __str__(self): ...

    def get_hier_coord(self, idx) -> Coord:
        """Return the (hierarchical) ND logical coordinate corresponding to the linear index"""
        ...

    # ``loc`` / ``ip`` are keyword-only and default to None. Plain attribute access on a
    # property passes no arguments, so they keep their defaults when read as ``x.shape``.
    @property
    def shape(self, *, loc=None, ip=None) -> Shape: ...

    # Same keyword-only ``loc`` / ``ip`` convention as ``shape``
    @property
    def stride(self, *, loc=None, ip=None) -> Stride: ...


# Recursive alias like the ones above; leaves may be an Int, None or a Layout
Tile = Union[Int, None, Layout, Tuple["Tile", ...]]

# XTuple is a superset of the tuple-like types above
XTuple = Union[IntTuple, Shape, Stride, Coord, Tile]

# Anything accepted as a tiler: a Shape, a Layout or a Tile
Tiler = Union[Shape, Layout, Tile]


class Pointer(ABC):
    """
    Abstract base class for CuTe jit function and runtime _Pointer

    None of the members are marked ``@abstractmethod``; except for ``value_type`` they are
    declarations with ``...`` bodies.
    """

    @property
    def value_type(self) -> Type[Numeric]:
        # Alias: simply returns ``dtype``
        return self.dtype

    # Declaration only (body is ``...``)
    @property
    def dtype(self) -> Type[Numeric]: ...

    # Declaration only; annotated to return a Pointer
    def align(self, min_align: int) -> "Pointer": ...

    # The three hooks below are declarations only.
    # NOTE(review): presumably the DSL's protocol for converting objects to and from
    # MLIR types/values - confirm against the JIT argument handling.
    def __get_mlir_types__(self) -> List[ir.Type]: ...

    def __extract_mlir_values__(self) -> List[ir.Value]: ...

    def __new_from_mlir_values__(self, values) -> "Pointer": ...


class Tensor(ABC):
    """
    Abstract base class for CuTe jit function and runtime _Tensor

    A CuTe Tensor is an iterator with a layout

    :Examples:

    Create tensor from torch.tensor with Host Runtime:

    .. code-block:: python

        >>> import torch
        >>> from cutlass.cute.runtime import from_dlpack
        >>> mA = from_dlpack(torch.tensor([1, 3, 5], dtype=torch.int32))
        >>> mA.shape
        (3,)
        >>> mA.stride
        (1,)
        >>> mA.layout
        (3,):(1,)

    Define JIT function:

    .. code-block:: python

        @cute.jit
        def add(a: Tensor, b: Tensor, res: Tensor): ...

    Call JIT function from python:

    .. code-block:: python

        >>> import torch
        >>> a = torch.tensor([1, 3, 5], dtype=torch.int32)
        >>> b = torch.tensor([2, 4, 6], dtype=torch.int32)
        >>> c = torch.zeros([3], dtype=torch.int32)
        >>> mA = from_dlpack(a)
        >>> mB = from_dlpack(b)
        >>> mC = from_dlpack(c)
        >>> add(mA, mB, mC)
        >>> c
        tensor([3, 7, 11], dtype=torch.int32)
    """

    # Declaration only (body is ``...``)
    def __str__(self): ...

    # Abstract: subclasses must implement indexing
    @abstractmethod
    def __getitem__(self, idx) -> Union["Tensor", ir.Value, IntTuple]: ...

    # Abstract: subclasses must implement item assignment
    @abstractmethod
    def __setitem__(self, idx, value): ...

    # Abstract read accessor for the element type
    @property
    @abstractmethod
    def element_type(self) -> Union[Type[Numeric], Type[IntTuple]]: ...

    # The setter is declared but, unlike the getter, is not marked abstract
    @element_type.setter
    def element_type(self, new_type): ...

    # Abstract: annotated to return an AddressSpace
    @property
    @abstractmethod
    def memspace(self) -> AddressSpace: ...

    # Abstract: no return annotation is given
    @property
    @abstractmethod
    def iterator(self): ...

    # Non-abstract declaration; "ComposedLayout" is referenced by name only and is not
    # defined or imported by name in this module
    @property
    def layout(self) -> Union[Layout, "ComposedLayout"]: ...

    # Non-abstract declaration
    @property
    def shape(self) -> Shape: ...

    # Non-abstract declarations; "TensorSSA" is referenced by name only
    def load(self, *, loc=None, ip=None) -> "TensorSSA": ...

    def store(self, data: "TensorSSA", *, loc=None, ip=None): ...

    # Non-abstract declaration; annotated to return a Tensor
    def mark_layout_dynamic(self, leading_dim: int | None = None) -> "Tensor": ...

    # Non-abstract declaration; annotated to return a Tensor
    def mark_compact_shape_dynamic(
        self,
        mode: int,
        stride_order: tuple[int, ...] | None = None,
        divisibility: int = 1,
    ) -> "Tensor": ...

    # Abstract: subclasses must implement fill
    @abstractmethod
    def fill(self, value: Numeric) -> None: ...
+ + +__all__ = [ + "Coord", + "Numeric", + "Integer", + "Boolean", + "Int8", + "Int16", + "Int32", + "Int64", + "Uint8", + "Uint16", + "Uint32", + "Uint64", + "Float", + "Float16", + "BFloat16", + "TFloat32", + "Float32", + "Float64", + "Float8E5M2", + "Float8E4M3FN", + "Float8E4M3B11FNUZ", + "Float8E4M3", + "Float8E8M0FNU", + "Float4E2M1FN", + "Float6E2M3FN", + "Float6E3M2FN", + "IntTuple", + "Layout", + "Pointer", + "Shape", + "Stride", + "Tensor", + "Tile", + "Tiler", + "XTuple", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/impl_utils.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/impl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0bb9b5207144a11665449fac431fcbe2bd8f49bd --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/impl_utils.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ + +def check_value_in( + value, possible_values: list, value_description: str, prefix="" +) -> None: + if value not in possible_values: + err_msg = prefix + if err_msg != "": + err_msg += ": " + err_msg += f"invalid {value_description}, got {value}, must be one of {possible_values}" + raise ValueError(err_msg) + + +def check_type_in(ty, possible_types: list, type_description: str, prefix="") -> None: + if not isinstance(ty, type): + ty = type(ty) + if ty not in possible_types: + err_msg = prefix + if err_msg != "": + err_msg += ": " + err_msg += f"invalid type for {type_description}, got {ty}, must be one of {possible_types}" + raise TypeError(err_msg) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7df24dd6bb6a5e42ebf5bad0e785cf77589bbbc6 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/__init__.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ +from .helpers import ( + Agent, + CooperativeGroup, + PipelineOp, + SyncObject, + MbarrierArray, + NamedBarrier, + TmaStoreFence, + PipelineUserType, + PipelineState, + make_pipeline_state, + pipeline_init_wait, + arrive, + arrive_unaligned, + wait, + wait_unaligned, + arrive_and_wait, + sync, +) + +from .sm90 import ( + PipelineAsync, + PipelineCpAsync, + PipelineTmaAsync, + PipelineTmaMultiConsumersAsync, + PipelineTmaStore, + PipelineProducer, + PipelineConsumer, +) + +from .sm100 import ( + PipelineTmaUmma, + PipelineAsyncUmma, + PipelineUmmaAsync, +) + +__all__ = [ + "Agent", + "CooperativeGroup", + "PipelineOp", + "SyncObject", + "MbarrierArray", + "NamedBarrier", + "TmaStoreFence", + "PipelineUserType", + "PipelineState", + "PipelineAsync", + "PipelineCpAsync", + "PipelineTmaAsync", + "PipelineTmaUmma", + "PipelineTmaMultiConsumersAsync", + "PipelineAsyncUmma", + "PipelineUmmaAsync", + "PipelineTmaStore", + "PipelineProducer", + "PipelineConsumer", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/helpers.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b94899435224ceda4bd152944e9a4b9bc2e911 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/helpers.py @@ -0,0 +1,652 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 

import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Union
import warnings

import cutlass.cute as cute
from cutlass.cutlass_dsl import Boolean, Int32, Int64, if_generate
from cutlass._mlir.dialects import llvm
import cutlass._mlir.dialects.cute as _cute_ir


##############################################################################
# Agent class
##############################################################################


class Agent(enum.Enum):
    """
    Agent indicates what is participating in the pipeline synchronization.
    """

    # Arbitrary grouping of N threads
    Thread = enum.auto()
    # Same as AsyncThread, but includes all threads in the block
    ThreadBlock = enum.auto()
    # Same as AsyncThread, but includes all threads in the cluster
    ThreadBlockCluster = enum.auto()


class CooperativeGroup:
    """
    CooperativeGroup contains size and alignment restrictions for an Agent.

    Only ``Agent.Thread`` is currently accepted; the other agents raise
    ``NotImplementedError``. ``alignment`` is validated but not stored.
    """

    def __init__(self, agent: Agent, size: int = 1, alignment: int = 1):
        if agent is Agent.Thread:
            assert size > 0
            if size == 32:
                assert (
                    size == alignment
                ), "Error: Alignment does not match number of threads in a warp."
            elif size == 128:
                assert (
                    size == alignment
                ), "Error: Alignment does not match number of threads in a warpgroup."
        elif agent is Agent.ThreadBlock:
            raise NotImplementedError("Error: Not yet supported.")
        elif agent is Agent.ThreadBlockCluster:
            raise NotImplementedError("Error: Not yet supported.")
        else:
            # Should never reach this state
            size = 0

        if size <= 0:
            raise ValueError(
                "Error: The number of threads in a CooperativeGroup must be more than 0."
            )

        # Size indicates how many threads are participating in this CooperativeGroup
        self.size = size
        # Agent indicates the type of thread group
        self.agent = agent


class PipelineOp(enum.Enum):
    """
    PipelineOp assigns an operation to an agent corresponding to a specific hardware feature.
    """

    # async-threads
    AsyncThread = enum.auto()
    # Blackwell (SM100a) MMA instruction
    TCGen05Mma = enum.auto()
    # Tensor Memory Accelerator load
    TmaLoad = enum.auto()
    # TMA Store consuming smem produced by AsyncThread
    TmaStore = enum.auto()
    # Composite of multiple PipelineOps
    Composite = enum.auto()
    # Async load without TMA
    AsyncLoad = enum.auto()


def _get_pipeline_op(type_str):
    """Look up the ``PipelineOp`` member for ``type_str`` via the enum constructor."""
    return PipelineOp(type_str)


##############################################################################
# SyncObject class
##############################################################################


class SyncObject(ABC):
    """Abstract base class for hardware synchronization primitives.

    This class defines the interface for different types of hardware synchronization
    mechanisms including shared memory barriers, named barriers, and fences.
    """

    @abstractmethod
    def arrive(self) -> None:
        pass

    @abstractmethod
    def wait(self) -> None:
        pass

    @abstractmethod
    def arrive_and_wait(self) -> None:
        pass

    @abstractmethod
    def arrive_and_drop(self) -> None:
        pass

    @abstractmethod
    def get_barrier(self) -> Union[cute.Pointer, int, None]:
        pass

    @abstractmethod
    def max(self) -> Union[int, None]:
        pass


class MbarrierArray(SyncObject):
    """
    MbarrierArray implements an abstraction for an array of smem barriers.
    """

    def __init__(
        self,
        barrier_storage: cute.Pointer,
        num_stages: int,
        agent: tuple[PipelineOp, CooperativeGroup],
        tx_count: int = 0,
    ) -> None:
        """Validate the configuration and initialize the barriers.

        :param barrier_storage: Base pointer of the barrier array
        :param num_stages: Number of barriers in the array, must be greater than 0
        :param agent: ``(PipelineOp, CooperativeGroup)`` pair; the group size is used as
            the arrive count
        :param tx_count: Transaction count passed to ``arrive_and_expect_tx`` for
            ``TmaLoad``, defaults to 0
        :raises ValueError: If the stage count or arrive count is not positive, or if
            ``tx_count`` is negative for ``TmaLoad``
        """
        self.barrier_storage = barrier_storage
        self.tx_count = tx_count
        self.num_stages = num_stages
        self.op_type, self.cg = agent
        self.arrive_count = self.cg.size

        if self.num_stages <= 0:
            raise ValueError("Error: Mbarrier stage count must be greater than 0.")
        if self.arrive_count <= 0:
            raise ValueError("Error: Mbarrier arrive count must be greater than 0.")
        if self.op_type is PipelineOp.TmaLoad and self.tx_count < 0:
            raise ValueError(
                "Error: Mbarrier tx count must not be less than 0 for TMA ops."
            )

        # Store mbarrier base pointer
        self.mbarrier_base = self.barrier_storage

        # Mbarrier initialization in constructor
        self.mbarrier_init()

    def recast_to_new_op_type(self, new_op_type: PipelineOp) -> "MbarrierArray":
        """
        Creates a copy of MbarrierArray with a different op_type without re-initializing barriers
        """
        # Create new instance without initialization
        new_mbarrier_array = object.__new__(MbarrierArray)

        # Copy all attributes directly
        new_mbarrier_array.barrier_storage = self.barrier_storage
        new_mbarrier_array.op_type = new_op_type
        new_mbarrier_array.cg = self.cg
        new_mbarrier_array.num_stages = self.num_stages
        new_mbarrier_array.tx_count = self.tx_count
        new_mbarrier_array.arrive_count = self.arrive_count
        new_mbarrier_array.mbarrier_base = self.mbarrier_base
        return new_mbarrier_array

    # Mbarrier initialization
    def mbarrier_init(self) -> None:
        """
        Initializes an array of mbarriers using warp 0.
        """

        def then_body():
            for index in range(self.num_stages):
                cute.arch.mbarrier_init(self.get_barrier(index), self.arrive_count)

        warp_idx = cute.arch.warp_idx()
        warp_idx = cute.arch.make_warp_uniform(warp_idx)

        if_generate(warp_idx == 0, then_body)

    def arrive(
        self,
        index: int,
        dst: Optional[int],
        cta_group: Optional[cute.nvgpu.tcgen05.CtaGroup] = None,
    ) -> None:
        """Select the arrive corresponding to this MbarrierArray's PipelineOp.

        :param index: Index of the mbarrier in the array to arrive on
        :type index: int
        :param dst: Destination parameter for selective arrival, which can be either a mask or destination cta rank.
            When None, both ``TCGen05Mma`` and ``AsyncThread`` will arrive on their local mbarrier.
            - For ``TCGen05Mma``, ``dst`` serves as a multicast mask (e.g., 0b1011 allows arrive signal to be multicast to CTAs
            in the cluster with rank = 0, 1, and 3).
            - For ``AsyncThread``, ``dst`` serves as a destination cta rank (e.g., 3 means threads will arrive on
            the mbarrier with rank = 3 in the cluster).
        :type dst: int | None
        :param cta_group: CTA group for ``TCGen05Mma``, defaults to None for other op types
        :type cta_group: ``cute.nvgpu.tcgen05.CtaGroup``, optional
        """
        if self.op_type is PipelineOp.AsyncThread:
            self.arrive_mbarrier(index, dst)
        elif self.op_type is PipelineOp.TCGen05Mma:
            assert (
                cta_group is not None
            ), "Error: CTA group must be provided for TCGen05Mma."
            self.arrive_tcgen05mma(index, dst, cta_group)
        elif self.op_type in [PipelineOp.TmaLoad]:
            self.arrive_and_expect_tx(index, self.tx_count)
        elif self.op_type is PipelineOp.AsyncLoad:
            self.arrive_cp_async_mbarrier(index)
        else:
            assert (
                False
            ), f"Error: MbarrierArray is not supported for PipelineOp: {_get_pipeline_op(self.op_type)}."

    def arrive_mbarrier(self, index: int, dst_rank: Optional[int] = None) -> None:
        """Arrive on barrier ``index``; ``dst_rank`` is forwarded only when it is not None."""
        if dst_rank is None:
            cute.arch.mbarrier_arrive(self.get_barrier(index))
        else:
            cute.arch.mbarrier_arrive(self.get_barrier(index), dst_rank)

    def arrive_cp_async_mbarrier(self, index: int):
        """Arrive on barrier ``index`` through ``cp_async_mbarrier_arrive_noinc``."""
        cute.arch.cp_async_mbarrier_arrive_noinc(self.get_barrier(index))

    def arrive_tcgen05mma(
        self, index: int, mask: Optional[int], cta_group: cute.nvgpu.tcgen05.CtaGroup
    ) -> None:
        """Commit to barrier ``index`` inside ``elect_one``.

        ``mask`` and ``cta_group`` are forwarded to the commit only when ``mask`` is not None.
        """
        if mask is None:
            with cute.arch.elect_one():
                cute.nvgpu.tcgen05.commit(self.get_barrier(index))
        else:
            with cute.arch.elect_one():
                cute.nvgpu.tcgen05.commit(self.get_barrier(index), mask, cta_group)

    def arrive_and_expect_tx(self, index: int, tx_count: int) -> None:
        """Arrive on barrier ``index`` and expect ``tx_count``, inside ``elect_one``."""
        with cute.arch.elect_one():
            cute.arch.mbarrier_arrive_and_expect_tx(self.get_barrier(index), tx_count)

    def try_wait(self, index: int, phase: int) -> Boolean:
        """Return the result of ``mbarrier_try_wait`` on barrier ``index`` for ``phase``."""
        return cute.arch.mbarrier_try_wait(self.get_barrier(index), phase)

    def wait(self, index: int, phase: int) -> None:
        """Wait on barrier ``index`` for ``phase``."""
        cute.arch.mbarrier_wait(self.get_barrier(index), phase)

    def arrive_and_wait(
        self,
        index: int,
        phase: int,
        dst: Optional[int],
        cta_group: Optional[cute.nvgpu.tcgen05.CtaGroup] = None,
    ) -> None:
        """Arrive on barrier ``index`` and then wait on it for ``phase``.

        Parameters have the same meaning as in :meth:`arrive` and :meth:`wait`.
        """
        # Must dispatch to the methods of this object. The bare names ``arrive`` / ``wait``
        # refer to the module-level NamedBarrier helpers, which take
        # (barrier_id, num_threads) and would fail with these arguments.
        self.arrive(index, dst, cta_group)
        self.wait(index, phase)

    def arrive_and_drop(self) -> None:
        raise NotImplementedError("Error: Not yet supported.")

    def get_barrier(self, index: int) -> cute.Pointer:
        """Return the pointer of barrier ``index`` (base pointer offset by ``index``)."""
        return self.mbarrier_base + index

    def max(self) -> int:
        # Transaction barriers have a maximum arrive count of 511 (2^9 - 1).
        # Non-transaction barriers have a maximum arrive count of 1,048,575 (2^20 - 1).
        return 511

    def __extract_mlir_values__(self):
        return [self.barrier_storage]

    def __new_from_mlir_values__(self, values):
        # NOTE(review): this goes through the constructor, which calls mbarrier_init()
        # again - confirm that re-initialization is intended here.
        return MbarrierArray(
            values[0], self.num_stages, (self.op_type, self.cg), self.tx_count
        )


@dataclass(frozen=True)
class NamedBarrier(SyncObject):
    """
    NamedBarrier is an abstraction for named barriers managed by hardware.
    There are 16 named barriers available, with barrier_ids 0-15.

    See the PTX documentation for the ``barrier`` instruction.
    """

    barrier_id: int
    num_threads: int

    def __post_init__(self) -> None:
        if self.barrier_id < 0 or self.barrier_id >= 16:
            raise ValueError("Error: NamedBarrier ID must be between 0 and 15.")
        if self.barrier_id == 0:
            warnings.warn(
                "NamedBarrier ID 0 is used by other driver APIs (i.e. sync_threads()) and should not be used."
            )

    def arrive(self) -> None:
        """
        The aligned flavor of arrive is used when all threads in the CTA will execute the
        same instruction. See PTX documentation.
        """
        cute.arch.barrier_arrive(
            barrier_id=self.barrier_id, number_of_threads=self.num_threads
        )

    def arrive_unaligned(self) -> None:
        """
        The unaligned flavor of arrive can be used with an arbitrary number of threads in the CTA.
        """
        llvm.inline_asm(
            None,
            [Int32(self.barrier_id).ir_value(), Int32(self.num_threads).ir_value()],
            "barrier.arrive $0, $1;",
            "r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )

    def wait(self) -> None:
        """
        NamedBarriers do not have a standalone wait like mbarriers, only an arrive_and_wait.
        If synchronizing two warps in a producer/consumer pairing, the arrive count would be
        32 using mbarriers but 64 using NamedBarriers. Only threads from either the producer
        or consumer are counted for mbarriers, while all threads participating in the sync
        are counted for NamedBarriers.
        """
        warnings.warn(
            "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()."
        )
        self.arrive_and_wait()

    def wait_unaligned(self) -> None:
        """Emit ``barrier.sync`` inline assembly for this barrier, after warning that it also arrives."""
        warnings.warn(
            "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()."
        )
        llvm.inline_asm(
            None,
            [Int32(self.barrier_id).ir_value(), Int32(self.num_threads).ir_value()],
            "barrier.sync $0, $1;",
            "r,r",
            has_side_effects=True,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )

    def arrive_and_wait(self) -> None:
        """Call ``cute.arch.barrier`` with this barrier's id and thread count."""
        cute.arch.barrier(
            barrier_id=self.barrier_id, number_of_threads=self.num_threads
        )

    def arrive_and_drop(self) -> None:
        raise NotImplementedError("Error: Not supported.")

    def sync(self) -> None:
        """Call ``cute.arch.barrier`` with this barrier's id only (no thread count)."""
        cute.arch.barrier(barrier_id=self.barrier_id)

    def get_barrier(self) -> int:
        return self.barrier_id

    def max(self) -> int:
        # Transaction barriers have a maximum arrive count of 4095 (2^12 - 1).
        return 4095


class TmaStoreFence(SyncObject):
    """
    TmaStoreFence is used for a multi-stage epilogue buffer.
    """

    def __init__(self, num_stages: int = 0) -> None:
        if num_stages <= 0:
            raise ValueError("Mbarrier stage count must be greater than 0.")

        self.num_stages = num_stages

    def arrive(self) -> None:
        """Commit the current bulk async-copy group."""
        cute.arch.cp_async_bulk_commit_group()

    def wait(self) -> None:
        """Wait on bulk async-copy groups with a count of ``num_stages - 1`` and ``read=True``."""
        cute.arch.cp_async_bulk_wait_group(self.num_stages - 1, read=True)

    def arrive_and_wait(self) -> None:
        self.arrive()
        self.wait()

    def arrive_and_drop(self) -> None:
        raise NotImplementedError("Error: Not supported.")

    # TmaStoreFence doesn't have mbarriers
    def get_barrier(self) -> None:
        assert (
            False
        ), "Error: TmaStoreFence doesn't use mbarriers and cannot return a barrier."

    def max(self) -> None:
        raise NotImplementedError("Error: Not supported.")

    def tail(self) -> None:
        """Wait on bulk async-copy groups with a count of 0 and ``read=True``."""
        cute.arch.cp_async_bulk_wait_group(0, read=True)


##############################################################################
# PipelineState class
##############################################################################


class PipelineUserType(enum.Enum):
    Producer = enum.auto()
    Consumer = enum.auto()


class PipelineState:
    """
    Pipeline state contains an index and phase bit corresponding to the current position in the circular buffer.
    """

    def __init__(self, stages: int, count, index, phase):
        self._stages = stages
        self._count = count
        self._index = index
        self._phase = phase

    def clone(self) -> "PipelineState":
        """Return a new PipelineState holding the same stages, count, index and phase."""
        return PipelineState(self.stages, self._count, self.index, self.phase)

    @property
    def index(self) -> Int32:
        return self._index

    @property
    def count(self) -> Int32:
        return self._count

    @property
    def stages(self) -> int:
        return self._stages

    @property
    def phase(self) -> Int32:
        return self._phase

    def reset_count(self):
        self._count = Int32(0)

    def advance(self):
        """Step forward by one: on reaching ``stages`` the index wraps to 0 and the phase bit flips."""
        self._index += 1
        self._count += 1

        def then_body(index, phase):
            new_index = Int32(0)
            new_phase = phase ^ 1
            return new_index, new_phase

        def else_body(index, phase):
            return index, phase

        self._index, self._phase = if_generate(
            self._index == self.stages,
            then_body,
            else_body,
            [self.index, self.phase],
            [Int32, Int32],
        )

    def reverse(self):
        """Step back by one: on reaching -1 the index wraps to ``stages - 1`` and the phase bit flips."""
        self._index -= 1
        self._count -= 1

        def then_body(index, phase):
            new_index = Int32(self.stages - 1)
            new_phase = phase ^ 1
            return new_index, new_phase

        def else_body(index, phase):
            return index, phase

        self._index, self._phase = if_generate(
            self._index == -1,
            then_body,
            else_body,
            [self.index, self.phase],
            [Int32, Int32],
        )

    def __get_mlir_types__(self):
        return [self._count.type, self._index.type, self._phase.type]

    def __extract_mlir_values__(self):
        count = self._count
        index = self._index
        phase = self._phase
        return [count.ir_value(), index.ir_value(), phase.ir_value()]

    # This can be overridden by derived classes
    def __new_from_mlir_values__(self, values):
        return PipelineState(
            self.stages, Int32(values[0]), Int32(values[1]), Int32(values[2])
        )


def make_pipeline_state(type: PipelineUserType, stages: int):
    """
    Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1.
    """
    if type is PipelineUserType.Producer:
        return PipelineState(
            stages,
            Int32(0),
            Int32(0),
            Int32(1),
        )
    elif type is PipelineUserType.Consumer:
        return PipelineState(
            stages,
            Int32(0),
            Int32(0),
            Int32(0),
        )
    else:
        assert (
            False
        ), "Error: invalid PipelineUserType specified for make_pipeline_state."


##############################################################################
# Helper functions
##############################################################################


def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None):
    """
    Fences the mbarrier init and syncs the threadblock or cluster
    """
    cute.arch.mbarrier_init_fence()

    if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
        # If not using clusters, sync the threadblock
        _sync(Agent.ThreadBlock)
    else:
        # If using clusters, sync the cluster
        _sync(Agent.ThreadBlockCluster)


def _sync(group: Agent):
    """
    Syncs all threads within an agent.
    """
    if group is Agent.Thread:
        raise NotImplementedError("Error: Not supported.")
    elif group is Agent.ThreadBlock:
        cute.arch.sync_threads()
    elif group is Agent.ThreadBlockCluster:
        cute.arch.cluster_arrive()
        cute.arch.cluster_wait()
    else:
        assert (
            False
        ), "Error: No explicit sync instruction exists. Please use barriers (named / mbarrier) instead."


def _mbarrier_i64_to_ptr(val: Int64) -> cute.Pointer:
    """
    Converts a smem pointer of type Int64 to cute.Pointer with 8B alignment
    """
    return cute.make_ptr(
        Int64,
        val.ir_value(),
        mem_space=_cute_ir.AddressSpace.smem,
        assumed_align=8,
    )


# NamedBarrier free functions
def arrive(barrier_id: int, num_threads: int):
    """
    The aligned flavor of arrive is used when all threads in the CTA will execute the
    same instruction. See PTX documentation.
    """
    cute.arch.barrier_arrive(barrier_id=barrier_id, number_of_threads=num_threads)


def arrive_unaligned(barrier_id: int, num_threads: int):
    """
    The unaligned flavor of arrive can be used with an arbitrary number of threads in the CTA.
    """
    llvm.inline_asm(
        None,
        [Int32(barrier_id).ir_value(), Int32(num_threads).ir_value()],
        "barrier.arrive $0, $1;",
        "r,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


def wait(barrier_id: int, num_threads: int):
    """
    NamedBarriers do not have a standalone wait like mbarriers, only an arrive_and_wait.
    If synchronizing two warps in a producer/consumer pairing, the arrive count would be
    32 using mbarriers but 64 using NamedBarriers. Only threads from either the producer
    or consumer are counted for mbarriers, while all threads participating in the sync
    are counted for NamedBarriers.
    """
    warnings.warn(
        "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()."
    )
    # Forward the barrier id and thread count; arrive_and_wait() requires both
    arrive_and_wait(barrier_id, num_threads)


def wait_unaligned(barrier_id: int, num_threads: int):
    """Emit ``barrier.sync`` inline assembly for ``barrier_id``, after warning that it also arrives."""
    warnings.warn(
        "NamedBarrier wait also arrives on the barrier. Routing call to NamedBarrier.arrive_and_wait()."
    )
    llvm.inline_asm(
        None,
        [Int32(barrier_id).ir_value(), Int32(num_threads).ir_value()],
        "barrier.sync $0, $1;",
        "r,r",
        has_side_effects=True,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )


def arrive_and_wait(barrier_id: int, num_threads: int):
    """Call ``cute.arch.barrier`` with the given barrier id and thread count."""
    cute.arch.barrier(barrier_id=barrier_id, number_of_threads=num_threads)


def sync(barrier_id: int = 0):
    """Call ``cute.arch.barrier`` with the given barrier id only, defaults to barrier 0."""
    cute.arch.barrier(barrier_id=barrier_id)


# --- patch file boundary (kept verbatim as comments) ---
# diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/sm100.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/sm100.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..2feed8cc0f1e702557f0c2b21b7582651a6405b8
# --- /dev/null
# +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/sm100.py
# @@ -0,0 +1,453 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
+ +import enum +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Union +import warnings + +import cutlass.cute as cute +from cutlass.cutlass_dsl import Boolean, if_generate + +from cutlass.pipeline import ( + Agent, + CooperativeGroup, + PipelineOp, + PipelineState, + pipeline_init_wait, + PipelineAsync, +) + +############################################################################## +# Pipeline classes +############################################################################## + + +@dataclass(frozen=True) +class PipelineTmaUmma(PipelineAsync): + """ + PipelineTmaUmma is used for TMA producers and UMMA consumers (e.g. Blackwell mainloops). + """ + + is_leader_cta: bool + cta_group: cute.nvgpu.tcgen05.CtaGroup + + @staticmethod + def _compute_mcast_arrival_mask(cta_layout_vmnk: cute.Layout): + """ + Computes a mask for signaling arrivals to multicasting threadblocks. + """ + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + + tma_mcast_mask_a = cute.nvgpu.cpasync.create_tma_multicast_mask( + cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=2 + ) + tma_mcast_mask_b = cute.nvgpu.cpasync.create_tma_multicast_mask( + cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=1 + ) + + block_in_cluster_coord_vmnk_peer = ( + cta_in_cluster_coord_vmnk[0] ^ 1, + *cta_in_cluster_coord_vmnk[1:], + ) + tma_mcast_mask_a_peer = cute.nvgpu.cpasync.create_tma_multicast_mask( + cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=2 + ) + tma_mcast_mask_b_peer = cute.nvgpu.cpasync.create_tma_multicast_mask( + cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=1 + ) + + return ( + tma_mcast_mask_a + | tma_mcast_mask_b + | tma_mcast_mask_a_peer + | tma_mcast_mask_b_peer + ) + + @staticmethod + def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout): + """ + Computes leader 
threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders. + """ + bidx, bidy, _ = cute.arch.block_idx() + + mma_coord_vmnk = ( + bidx % cute.size(cta_layout_vmnk, mode=[0]), + bidx // cute.size(cta_layout_vmnk, mode=[0]), + bidy, + None, + ) + return mma_coord_vmnk[0] == 0 + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + tx_count: int, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma. + :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: `CooperativeGroup` for the consumer agent + :type consumer_group: CooperativeGroup + :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage + :type tx_count: int + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.TmaLoad + consumer_type = PipelineOp.TCGen05Mma + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer, tx_count + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + # No mcast mask if not using clusters + producer_mask 
= None + # All threadblocks are leaders if not using clusters + is_leader_cta = True + else: + producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk) + is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk) + + cta_group = ( + cute.nvgpu.tcgen05.CtaGroup.ONE + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + else cute.nvgpu.tcgen05.CtaGroup.TWO + ) + + consumer_mask = producer_mask + + pipeline_init_wait(cta_layout_vmnk) + + return PipelineTmaUmma( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + consumer_mask, + is_leader_cta, + cta_group, + ) + + def consumer_release(self, state: PipelineState): + """ + UMMA consumer release buffer empty, cta_group needs to be provided. + """ + self.sync_object_empty.arrive(state.index, self.consumer_mask, self.cta_group) + + def producer_acquire( + self, state: PipelineState, try_acquire_token: Optional[Boolean] = None + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase), + ) + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive(state.index, self.producer_mask), + ) + + def producer_commit(self, state: PipelineState): + """ + TMA producer commit is a noop since TMA instruction itself updates the transaction count. + """ + pass + + +@dataclass(frozen=True) +class PipelineAsyncUmma(PipelineAsync): + """ + PipelineAsyncUmma is used for AsyncThread producers and UMMA consumers (e.g. Blackwell input fusion pipelines). + """ + + cta_group: cute.nvgpu.tcgen05.CtaGroup + + @staticmethod + def _compute_leading_cta_rank(cta_v_size): + """ + Computes the leading CTA rank. 
+ """ + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + return cta_rank_in_cluster // cta_v_size * cta_v_size + + @staticmethod + def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout): + """ + Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders. + """ + bidx, bidy, _ = cute.arch.block_idx() + mma_coord_vmnk = ( + bidx % cute.size(cta_layout_vmnk, mode=[0]), + bidx // cute.size(cta_layout_vmnk, mode=[0]), + bidy, + None, + ) + return mma_coord_vmnk[0] == 0 + + @staticmethod + def _compute_peer_cta_mask(cta_layout_vmnk: cute.Layout): + """ + Computes a mask for signaling arrivals to multicasting threadblocks. + """ + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + mask_self = cute.nvgpu.cpasync.create_tma_multicast_mask( + cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=0 + ) + block_in_cluster_coord_vmnk_peer = ( + cta_in_cluster_coord_vmnk[0] ^ 1, + *cta_in_cluster_coord_vmnk[1:], + ) + mask_peer = cute.nvgpu.cpasync.create_tma_multicast_mask( + cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=0 + ) + return mask_self | mask_peer + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineAsyncUmma. 
+ :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: `CooperativeGroup` for the consumer agent + :type consumer_group: CooperativeGroup + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.AsyncThread + consumer_type = PipelineOp.TCGen05Mma + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), + num_stages, + producer, + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + cta_v_size = ( + cute.size(cta_layout_vmnk, mode=[0]) if cta_layout_vmnk is not None else 1 + ) + cta_group = ( + cute.nvgpu.tcgen05.CtaGroup.ONE + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + else cute.nvgpu.tcgen05.CtaGroup.TWO + ) + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1: + # No mcast mask if we're not using 2CTA tcgen05 MMA + producer_mask = None + consumer_mask = None + else: + # If we're using 2CTA UMMAs, producer will arrive the mbar on leading CTA + # We need to get the target cta_rank + producer_mask = PipelineAsyncUmma._compute_leading_cta_rank(cta_v_size) + # consumer needs to get the mask to signal + consumer_mask = PipelineAsyncUmma._compute_peer_cta_mask(cta_layout_vmnk) + + pipeline_init_wait(cta_layout_vmnk) + + return PipelineAsyncUmma( + sync_object_full, + sync_object_empty, + 
num_stages, + producer_mask, + consumer_mask, + cta_group, + ) + + def consumer_release(self, state: PipelineState): + """ + UMMA consumer release buffer empty, cta_group needs to be provided. + """ + self.sync_object_empty.arrive(state.index, self.consumer_mask, self.cta_group) + + +@dataclass(frozen=True) +class PipelineUmmaAsync(PipelineAsync): + """ + PipelineUmmaAsync is used for UMMA producers and AsyncThread consumers (e.g. Blackwell accumulator pipelines). + """ + + cta_group: cute.nvgpu.tcgen05.CtaGroup + + @staticmethod + def _compute_tmem_sync_mask(cta_layout_vmnk: cute.Layout): + """ + Computes a mask to signal completion of tmem buffers for 2CTA kernels. + """ + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + return cute.make_layout_image_mask( + cta_layout_vmnk, cta_in_cluster_coord_vmnk, mode=0 + ) + + @staticmethod + def _compute_peer_cta_rank(): + """ + Computes a mask to signal release of tmem buffers for 2CTA kernels. + """ + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + return cta_rank_in_cluster // 2 * 2 + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineUmmaAsync. 
+ :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: `CooperativeGroup` for the consumer agent + :type consumer_group: CooperativeGroup + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.TCGen05Mma + consumer_type = PipelineOp.AsyncThread + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + # Set mask to None if not using clusters (i.e. 
1CTA kernels) + producer_mask = None + else: + producer_mask = PipelineUmmaAsync._compute_tmem_sync_mask(cta_layout_vmnk) + + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1: + # Set mask to None if not using 2CTA intructions + consumer_mask = None + else: + consumer_mask = PipelineUmmaAsync._compute_peer_cta_rank() + + cta_group = ( + cute.nvgpu.tcgen05.CtaGroup.ONE + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + else cute.nvgpu.tcgen05.CtaGroup.TWO + ) + + pipeline_init_wait(cta_layout_vmnk) + + return PipelineUmmaAsync( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + consumer_mask, + cta_group, + ) + + def producer_commit(self, state: PipelineState): + """ + UMMA producer commit buffer full, cta_group needs to be provided. + """ + self.sync_object_full.arrive(state.index, self.producer_mask, self.cta_group) + + def producer_tail(self, state: PipelineState): + """ + Make sure the last used buffer empty signal is visible to producer. + Producer tail is usually executed by producer before exit, to avoid dangling + mbarrier arrive signals after kernel exit. 
+ + :param state: The pipeline state that points to next useful buffer + :type state: PipelineState + """ + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + + def then_body(): + # Assume state contains that next useful buffer + # So we only need to advance to num_stages - 1 times to last used buffer + for i in range(self.num_stages - 1): + state.advance() + self.producer_acquire(state) + + if_generate(is_leader_cta, then_body) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/sm90.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/sm90.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc19960c9b1ccca84dcc18bca002e2fa2a303ca --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/pipeline/sm90.py @@ -0,0 +1,985 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ +import enum +from typing import Type, Tuple +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Union +import warnings + +import cutlass +import cutlass.cute as cute +from cutlass.cutlass_dsl import Boolean, Int32, if_generate + +from cutlass.pipeline import ( + Agent, + CooperativeGroup, + PipelineOp, + SyncObject, + MbarrierArray, + TmaStoreFence, + PipelineUserType, + PipelineState, + make_pipeline_state, + pipeline_init_wait, +) + +############################################################################## +# Pipeline classes +############################################################################## + + +@dataclass(frozen=True) +class PipelineAsync: + """PipelineAsync is a generic pipeline class where both the producer and consumer are + AsyncThreads. It also serves as a base class for specialized pipeline classes. + + This class implements a producer-consumer pipeline pattern where both sides operate + asynchronously. The pipeline maintains synchronization state using barrier objects + to coordinate between producer and consumer threads. + + The pipeline state transitions of one pipeline entry(mbarrier) can be represented as: + + .. 
table:: Pipeline State Transitions + :widths: auto + + +-----------+-----------+-----------+-----------+-----------+-----------+ + | Barrier | State | p.acquire | p.commit | c.wait | c.release | + +===========+===========+===========+===========+===========+===========+ + | empty_bar | empty | | n/a | n/a | - | + +-----------+-----------+-----------+-----------+-----------+-----------+ + | empty_bar | wait | | n/a | n/a | -> empty | + +-----------+-----------+-----------+-----------+-----------+-----------+ + | full_bar | wait | n/a | -> full | | n/a | + +-----------+-----------+-----------+-----------+-----------+-----------+ + | full_bar | full | n/a | - | | n/a | + +-----------+-----------+-----------+-----------+-----------+-----------+ + + Where: + + - p: producer + - c: consumer + - : This action is blocked until transition to a state allow it to proceed by other side + - e.g. ``p.acquire()`` is blocked until ``empty_bar`` transition to ``empty`` state by ``c.release()`` + + .. code-block:: text + + Array of mbarriers as circular buffer: + + Advance Direction + <------------------- + + Producer Consumer + | ^ + V | + +-----------------+ + --|X|X|W|D|D|D|D|R|X|<-. + / +-----------------+ \\ + | | + `------------------------' + + Where: + + - X: Empty buffer (initial state) + - W: Producer writing (producer is waiting for buffer to be empty) + - D: Data ready (producer has written data to buffer) + - R: Consumer reading (consumer is consuming data from buffer) + + **Example:** + + .. 
code-block:: python + + # Create pipeline with 5 stages + pipeline = PipelineAsync.create( + num_stages=5, # number of pipeline stages + producer_group=producer_warp, + consumer_group=consumer_warp + barrier_storage=smem_ptr, # smem pointer for array of mbarriers in shared memory + ) + + producer, consumer = pipeline.make_participants() + # Producer side + for i in range(num_iterations): + handle = producer.acquire_and_advance() # Wait for buffer to be empty & Move index to next stage + # Write data to pipeline buffer + handle.commit() # Signal buffer is full + + # Consumer side + for i in range(num_iterations): + handle = consumer.wait_and_advance() # Wait for buffer to be full & Move index to next stage + # Read data from pipeline buffer + handle.release() # Signal buffer is empty + """ + + sync_object_full: SyncObject + sync_object_empty: SyncObject + num_stages: int + producer_mask: Optional[Int32] + consumer_mask: Optional[Int32] + + @staticmethod + def _make_sync_object( + barrier_storage: cute.Pointer, + num_stages: int, + agent: tuple[PipelineOp, CooperativeGroup], + tx_count: int = 0, + ) -> SyncObject: + """ + Returns a SyncObject corresponding to an agent's PipelineOp. + """ + if agent[0] in [ + PipelineOp.AsyncThread, + PipelineOp.TmaLoad, + PipelineOp.TCGen05Mma, + PipelineOp.Composite, + PipelineOp.AsyncLoad, + ]: + return MbarrierArray( + barrier_storage=barrier_storage, + num_stages=num_stages, + agent=agent, + tx_count=tx_count, + ) + elif agent[0] is PipelineOp.TmaStore: + # Path taken for AsyncTmaStore + return TmaStoreFence(num_stages=num_stages) + else: + assert False, "Error: Invalid PipelineOp specified." + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + barrier_storage: cute.Pointer = None, + producer_mask: Int32 = None, + consumer_mask: Int32 = None, + ): + """Creates and initializes a new PipelineAsync instance. 
+ + This helper function computes necessary attributes and returns an instance of PipelineAsync + with the specified configuration for producer and consumer synchronization. + + :param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: int + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: `CooperativeGroup` for the consumer agent + :type consumer_group: CooperativeGroup + :param producer_mask: Mask for signaling arrives for the producer agent, defaults to ``None`` + :type producer_mask: Int32, optional + :param consumer_mask: Mask for signaling arrives for the consumer agent, defaults to ``None`` + :type consumer_mask: Int32, optional + :return: A new PipelineAsync instance + :rtype: PipelineAsync + :raises ValueError: If barrier_storage is not a cute.Pointer instance + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.AsyncThread + consumer_type = PipelineOp.AsyncThread + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + pipeline_init_wait() + + return PipelineAsync( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + consumer_mask, + ) + + def producer_acquire( + self, state: PipelineState, try_acquire_token: Optional[Boolean] = None + ): + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase), + ) + + 
def producer_try_acquire(self, state: PipelineState): + return self.sync_object_empty.try_wait(state.index, state.phase) + + def producer_commit(self, state: PipelineState): + self.sync_object_full.arrive(state.index, self.producer_mask) + + def consumer_wait( + self, state: PipelineState, try_wait_token: Optional[Boolean] = None + ): + if_generate( + try_wait_token is None or try_wait_token == 0, + lambda: self.sync_object_full.wait(state.index, state.phase), + ) + + def consumer_try_wait(self, state: PipelineState): + return self.sync_object_full.try_wait(state.index, state.phase) + + def consumer_release(self, state: PipelineState): + self.sync_object_empty.arrive(state.index, self.consumer_mask) + + def producer_get_barrier(self, state: PipelineState) -> cute.Pointer: + return self.sync_object_full.get_barrier(state.index) + + def producer_tail(self, state: PipelineState): + """ + Make sure the last used buffer empty signal is visible to producer. + Producer tail is usually executed by producer before exit, to avoid dangling + mbarrier arrive signals after kernel exit. 
+ + :param state: The pipeline state that points to next useful buffer + :type state: PipelineState + """ + # Assume state contains that next useful buffer + # So we only need to advance to num_stages - 1 times to last used buffer + for i in range(self.num_stages - 1): + state.advance() + self.producer_acquire(state) + + # Util methods to manage produer and consumer + def make_producer(self): + state = make_pipeline_state(PipelineUserType.Producer, self.num_stages) + return PipelineProducer(self, state, self.sync_object_full.cg) + + def make_consumer(self): + state = make_pipeline_state(PipelineUserType.Consumer, self.num_stages) + return PipelineConsumer(self, state, self.sync_object_empty.cg) + + def make_participants(self): + return self.make_producer(), self.make_consumer() + + + +@dataclass(frozen=True) +class PipelineCpAsync(PipelineAsync): + """ + PipelineCpAsync is used for CpAsync producers and AsyncThread consumers (e.g. Hopper non-TMA mainloops). + """ + + @staticmethod + def create( + barrier_storage: cute.Pointer, + num_stages: Int32, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + producer_mask: Int32 = None, + consumer_mask: Int32 = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineAsync. 
+ :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: CooperativeGroup for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: CooperativeGroup for the consumer agent + :type consumer_group: CooperativeGroup + :param producer_mask: Mask for signaling arrives for the producer agent + :type producer_mask: Int32 | None + :param consumer_mask: Mask for signaling arrives for the consumer agent + :type consumer_mask: Int32 | None + """ + producer_type = PipelineOp.AsyncLoad + consumer_type = PipelineOp.AsyncThread + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_array_full = PipelineCpAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer + ) + sync_object_array_empty = PipelineCpAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + pipeline_init_wait() + + return PipelineCpAsync( + sync_object_array_full, + sync_object_array_empty, + num_stages, + producer_mask, + consumer_mask, + ) + + +@dataclass(frozen=True) +class PipelineTmaAsync(PipelineAsync): + """ + PipelineTmaAsync is used for TMA producers and AsyncThread consumers (e.g. Hopper mainloops). 
+ """ + + is_signalling_thread: Boolean + + @staticmethod + @cute.jit + def init_empty_barrier_arrive_signal(cta_layout_vmnk: cute.Layout, tidx: Int32): + """ + Initialize the empty barrier arrive signal + This function returns the destination cta rank and a boolean indicating if the signalling thread is the same as the current thread + """ + # Logic to optimally schedule Empty Arrives + cluster_shape_vmnk = cta_layout_vmnk.shape + + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + + tidx = tidx % 32 + is_signalling_thread = tidx < cute.size(cluster_shape_vmnk) + dst_rank = tidx % cute.size(cluster_shape_vmnk) + + dst_cta_coord = cta_layout_vmnk.get_hier_coord(dst_rank) + cur_cta_coord = cta_layout_vmnk.get_hier_coord(cta_rank_in_cluster) + + is_same_row = ( + dst_cta_coord[0] == cur_cta_coord[0] + and dst_cta_coord[1] == cur_cta_coord[1] + and dst_cta_coord[3] == cur_cta_coord[3] + ) + is_same_col = ( + dst_cta_coord[0] == cur_cta_coord[0] + and dst_cta_coord[2] == cur_cta_coord[2] + and dst_cta_coord[3] == cur_cta_coord[3] + ) + + is_same_row_or_col = is_same_row or is_same_col + is_signalling_thread_final = is_signalling_thread and is_same_row_or_col + + return dst_rank, is_signalling_thread_final + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + tx_count: int, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + tidx: Optional[Int32] = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync. 
+ :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: `CooperativeGroup` for the consumer agent + :type consumer_group: CooperativeGroup + :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage + :type tx_count: int + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + :param tidx: thread index to consumer async threads + :type tidx: Int32 | None + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.TmaLoad + consumer_type = PipelineOp.AsyncThread + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer, tx_count + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + if tidx is None: + tidx, _, _ = cute.arch.thread_idx() + if cta_layout_vmnk is None: + cta_layout_vmnk = cute.make_layout((1, 1, 1, 1)) + ( + dst_rank, + is_signalling_thread, + ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx) + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + dst_rank = None + else: + dst_rank = dst_rank + + producer_mask = None + + pipeline_init_wait(cta_layout_vmnk) + + return PipelineTmaAsync( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + dst_rank, + is_signalling_thread, + ) + + def producer_acquire( + self, state: PipelineState, try_acquire_token: Optional[Boolean] = 
None + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase), + ) + self.sync_object_full.arrive(state.index, self.producer_mask) + + def producer_commit(self, state: PipelineState): + """ + TMA producer commit is a noop since TMA instruction itself updates the transaction count. + """ + pass + + def consumer_release(self, state: PipelineState): + """ + TMA consumer release conditionally signals the empty buffer to the producer. + """ + if_generate( + self.is_signalling_thread, + lambda: self.sync_object_empty.arrive(state.index, self.consumer_mask), + ) + + +@dataclass(frozen=True) +class PipelineTmaMultiConsumersAsync(PipelineAsync): + """ + PipelineTmaMultiConsumersAsync is used for TMA producers and UMMA+Async consumers. + """ + + is_leader_cta: bool + sync_object_empty_umma: SyncObject + sync_object_empty_async: SyncObject + cta_group: cute.nvgpu.tcgen05.CtaGroup + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group_umma: CooperativeGroup, + consumer_group_async: CooperativeGroup, + tx_count: int, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineTmaMultiConsumersAsync. 
+ :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group_umma: `CooperativeGroup` for the UMMA consumer agent + :type consumer_group_umma: CooperativeGroup + :param consumer_group_async: `CooperativeGroup` for the AsyncThread consumer agent + :type consumer_group_async: CooperativeGroup + :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage + :type tx_count: int + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.TmaLoad + consumer_type = PipelineOp.Composite + consumer_type_umma = PipelineOp.TCGen05Mma + consumer_type_async = PipelineOp.AsyncThread + + if consumer_group_umma.agent != consumer_group_async.agent: + raise ValueError( + "UMMA and AsyncThread consumer groups must be the same agent" + ) + + if cta_layout_vmnk is not None and cute.size(cta_layout_vmnk) != 1: + raise ValueError( + f"PipelineTmaMultiConsumersAsync is not verified for cta_layout_vmnk != 1, cta_layout_vmnk:{cta_layout_vmnk}" + ) + + consumer_group = CooperativeGroup( + consumer_group_umma.agent, + consumer_group_umma.size + consumer_group_async.size, + ) + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer, tx_count + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + sync_object_empty_umma = 
sync_object_empty.recast_to_new_op_type( + consumer_type_umma + ) + sync_object_empty_async = sync_object_empty.recast_to_new_op_type( + consumer_type_async + ) + + # No mcast mask if not using clusters + producer_mask = None + consumer_mask = None + # All threadblocks are leaders if not using clusters + is_leader_cta = True + cta_group = ( + cute.nvgpu.tcgen05.CtaGroup.ONE + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + else cute.nvgpu.tcgen05.CtaGroup.TWO + ) + + pipeline_init_wait(cta_layout_vmnk) + + return PipelineTmaMultiConsumersAsync( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + consumer_mask, + is_leader_cta, + sync_object_empty_umma, + sync_object_empty_async, + cta_group, + ) + + def producer_acquire( + self, state: PipelineState, try_acquire_token: Optional[Boolean] = None + ): + """ + TMA producer acquire waits on buffer empty and sets the transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase), + ) + if_generate( + self.is_leader_cta, + lambda: self.sync_object_full.arrive(state.index, self.producer_mask), + ) + + def producer_commit(self, state: PipelineState): + """ + TMA producer commit is a noop since TMA instruction itself updates the transaction count. + """ + pass + + def consumer_release(self, state: PipelineState, op_type: PipelineOp): + if op_type == PipelineOp.TCGen05Mma: + self.sync_object_empty_umma.arrive( + state.index, self.consumer_mask, self.cta_group + ) + elif op_type == PipelineOp.AsyncThread: + self.sync_object_empty_async.arrive(state.index, self.consumer_mask) + else: + raise ValueError(f"Invalid PipelineOp specified. op_type:{op_type}") + + +@dataclass(frozen=True) +class PipelineTmaStore(PipelineAsync): + """ + PipelineTmaStore is used for synchronizing TMA stores in the epilogue. It does not use mbarriers. 
+ """ + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineTmaStore. + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + """ + + producer_type = PipelineOp.TmaStore + + producer = (producer_type, producer_group) + + sync_object_full = PipelineAsync._make_sync_object(None, num_stages, producer) + + return PipelineTmaStore(sync_object_full, None, num_stages, None, None) + + def producer_acquire(self): + self.sync_object_full.wait() + + def producer_commit(self): + self.sync_object_full.arrive() + + def consumer_wait(self): + assert False, "Error: PipelineTmaStore does not have a consumer agent." + + def consumer_release(self): + assert False, "Error: PipelineTmaStore does not have a consumer agent." + + def producer_tail(self): + self.sync_object_full.tail() + + +################################################################# +# Utilities to help user of pipeline to simplify the workflow +################################################################# + + +class ImmutableResourceHandle: + __origin: PipelineAsync + __immutable_state: PipelineState + + def __init__(self, origin: PipelineAsync, immutable_state: PipelineState): + self.__origin = origin + self.__immutable_state = immutable_state + + @property + def index(self): + """Get the index of the current pipeline stage.""" + return self.__immutable_state.index + + @property + def count(self): + """Get the count of how many handles this producer has committed. + This is useful for tracking the number of blocks that have been loaded from gmem. 
+ """ + return self.__immutable_state.count + + def get_origin(self): + """Get the original pipeline this resource handle belongs to.""" + return self.__origin + + def __extract_mlir_values__(self): + """Extract MLIR values from the current state. + + :return: List of MLIR values representing the current state + :rtype: list + """ + # TODO: need to handle pipeline as well + return self.__immutable_state.__extract_mlir_values__() + + def __new_from_mlir_values__(self, values): + """Create a new Producer instance from MLIR values. + + :param values: MLIR values to initialize the state + :type values: Any + :return: New Producer instance with state initialized from values + :rtype: Producer + """ + return self.__class__( + self.__origin, self.__immutable_state.__new_from_mlir_values__(values) + ) + +class PipelineProducer: + """A class representing a producer in an asynchronous pipeline. + + The Producer class manages the producer side of an asynchronous pipeline, handling + synchronization and state management for producing data. It provides methods for + acquiring, committing, and advancing through pipeline stages. + + :ivar __pipeline: The asynchronous pipeline this producer belongs to + :type __pipeline: PipelineAsync + :ivar __state: The current state of the producer in the pipeline + :type __state: PipelineState + :ivar __group: The cooperative group this producer operates in + :type __group: CooperativeGroup + + **Examples:** + + .. code-block:: python + + pipeline = PipelineAsync.create(...) 
+ producer = pipeline.create_producer(producer_group, stages) + for i in range(iterations): + handle = producer.acquire_and_advance() # Wait for buffer to be empty + # Produce data + producer.commit(handle) # Signal data is ready + # An alternative way to do this is: + # handle.commit() # Signal data is ready + """ + + __pipeline: PipelineAsync + __state: PipelineState + __group: CooperativeGroup + + class ImmutableResourceHandle(ImmutableResourceHandle): + @property + def barrier(self): + """Get the barrier pointer for the current pipeline stage. + + :return: Pointer to the barrier for the current stage + :rtype: cute.Pointer + """ + return self.get_origin().producer_get_barrier( + self._ImmutableResourceHandle__immutable_state + ) + + def commit(self): + """Signal that data production is complete for the current stage. + This allows consumers to start processing the data. + """ + self.get_origin().producer_commit( + self._ImmutableResourceHandle__immutable_state + ) + + def __init__(self, pipeline, state, group: CooperativeGroup): + """Initialize a new Producer instance. + + :param pipeline: The pipeline this producer belongs to + :type pipeline: PipelineAsync + :param state: Initial pipeline state + :type state: PipelineState + :param group: The cooperative group for synchronization + :type group: CooperativeGroup + """ + self.__pipeline = pipeline + self.__state = state + self.__group = group + + def acquire( + self, + try_acquire_token: Optional[Boolean] = None, + ) -> ImmutableResourceHandle: + """Wait for the current buffer to be empty before producing data. + This is a blocking operation. 
+ + :param try_acquire_token: Optional token to try to acquire the buffer + :type try_acquire_token: Optional[Boolean] + :return: A handle to the producer for committing the data + :rtype: ImmutableResourceHandle + """ + self.__pipeline.producer_acquire(self.__state, try_acquire_token) + handle = PipelineProducer.ImmutableResourceHandle( + self.__pipeline, self.__state.clone() + ) + return handle + + def advance(self): + """Move to the next pipeline stage.""" + self.__state.advance() + + def acquire_and_advance( + self, try_acquire_token: Optional[Boolean] = None + ) -> ImmutableResourceHandle: + """Wait for the current buffer to be empty before producing data. + Then advance to the next stage. + This is a blocking operation. + + :param try_acquire_token: Optional token to try to acquire the buffer + :type try_acquire_token: Optional[Boolean] + :return: A handle to the producer for committing the data + :rtype: ImmutableResourceHandle + """ + handle = self.acquire(try_acquire_token) + self.advance() + return handle + + def try_acquire(self) -> Boolean: + """Try to acquire the current buffer without blocking. + + :return: True if acquisition was successful, False otherwise + :rtype: Boolean + """ + return self.__pipeline.producer_try_acquire(self.__state) + + def commit(self, handle: Optional[ImmutableResourceHandle] = None): + """Signal that data production is complete for the current stage. + This allows consumers to start processing the data. + """ + if handle is not None: + assert ( + handle.get_origin() is self + ), "ResourceHandle does not belong to this PipelineProducer instance" + handle.commit() + else: + self.__pipeline.producer_commit(self.__state) + + def tail(self): + """Ensure all used buffers are properly synchronized before producer exit. + This should be called before the producer finishes to avoid dangling signals. 
+ """ + self.__pipeline.producer_tail(self.__state) + + def __extract_mlir_values__(self): + """Extract MLIR values from the current state. + + :return: List of MLIR values representing the current state + :rtype: list + """ + # TODO: need to handle pipeline as well + return self.__state.__extract_mlir_values__() + + def __new_from_mlir_values__(self, values): + """Create a new Producer instance from MLIR values. + + :param values: MLIR values to initialize the state + :type values: Any + :return: New Producer instance with state initialized from values + :rtype: Producer + """ + return PipelineProducer( + self.__pipeline, self.__state.__new_from_mlir_values__(values), self.__group + ) + +class PipelineConsumer: + """A class representing a consumer in an asynchronous pipeline. + + The Consumer class manages the consumer side of an asynchronous pipeline, handling + synchronization and state management for consuming data. It provides methods for + waiting, releasing, and advancing through pipeline stages. + + :ivar __pipeline: The asynchronous pipeline this consumer belongs to + :type __pipeline: PipelineAsync + :ivar __state: The current state of the consumer in the pipeline + :type __state: PipelineState + :ivar __group: The cooperative group this consumer operates in + :type __group: CooperativeGroup + + **Examples:** + .. code-block:: python + + pipeline = PipelineAsync.create(...) + consumer = pipeline.create_consumer(consumer_group, stages) + for i in range(iterations): + handle = consumer.wait_and_advance() # Wait for data to be ready + # Consume data + consumer.release(handle) # Signal buffer is empty + # An alternative way to do this is: + # handle.release() # Signal buffer is empty + """ + + __pipeline: PipelineAsync + __state: PipelineState + __group: CooperativeGroup + + class ImmutableResourceHandle(ImmutableResourceHandle): + def release(self): + """Signal that data production is complete for the current stage. 
+ This allows consumers to start processing the data. + """ + self.get_origin().consumer_release( + self._ImmutableResourceHandle__immutable_state + ) + + def __init__(self, pipeline, state: PipelineState, group: CooperativeGroup): + """Initialize a new Consumer instance. + + :param pipeline: The pipeline this consumer belongs to + :type pipeline: PipelineAsync + :param state: Initial pipeline state + :type state: PipelineState + :param group: The cooperative group for synchronization + :type group: CooperativeGroup + """ + self.__pipeline = pipeline + self.__group = group + self.__state = state + + def wait(self, try_wait_token: Optional[Boolean] = None) -> ImmutableResourceHandle: + """Wait for data to be ready in the current buffer. + This is a blocking operation. + + :param try_wait_token: Optional token to try to wait for the buffer + :type try_wait_token: Optional[Boolean] + :return: A handle to the consumer for releasing the data + :rtype: PipelineConsumerHandle + """ + self.__pipeline.consumer_wait(self.__state, try_wait_token) + handle = PipelineConsumer.ImmutableResourceHandle( + self.__pipeline, self.__state.clone() + ) + return handle + + def advance(self): + """Move to the next pipeline stage.""" + self.__state.advance() + + def wait_and_advance( + self, try_wait_token: Optional[Boolean] = None + ) -> ImmutableResourceHandle: + """Wait for data to be ready in the current buffer. + Then advance to the next stage. + This is a blocking operation. + + :param try_wait_token: Optional token to try to wait for the buffer + :type try_wait_token: Optional[Boolean] + :return: A handle to the consumer for releasing the data + :rtype: PipelineConsumerHandle + """ + handle = self.wait(try_wait_token) + self.advance() + return handle + + def try_wait(self) -> Boolean: + """Try to check if data is ready without blocking. 
+ + :return: True if data is ready, False otherwise + :rtype: Boolean + """ + return self.__pipeline.consumer_try_wait(self.__state) + + def release(self, handle: Optional[ImmutableResourceHandle] = None): + """Signal that data consumption is complete for the current stage. + This allows producers to start producing new data. + """ + if handle is not None: + assert ( + handle.get_origin() is self + ), "ResourceHandle does not belong to this PipelineConsumer instance" + handle.release() + else: + self.__pipeline.consumer_release(self.__state) + + def __extract_mlir_values__(self): + """Extract MLIR values from the current state. + + :return: List of MLIR values representing the current state + :rtype: list + """ + return self.__state.__extract_mlir_values__() + + def __new_from_mlir_values__(self, values): + """Create a new Consumer instance from MLIR values. + + :param values: MLIR values to initialize the state + :type values: Any + :return: New Consumer instance with state initialized from values + :rtype: Consumer + """ + # TODO: need to call pipeline.__new_from_mlir_values__ recursively + return PipelineConsumer( + self.__pipeline, self.__state.__new_from_mlir_values__(values), self.__group + ) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/torch.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/torch.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ee5777cad35487f30b8705ff19747405d11194 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/torch.py @@ -0,0 +1,311 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +import ctypes +from math import prod +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Type, Union + +from cutlass.cute.typing import ( + Numeric, + Boolean, + Float, + Integer, + TFloat32, + Float8E4M3B11FNUZ, + Float8E4M3FN, + Float8E5M2, + Float8E8M0FNU, + Float4E2M1FN, + Tensor, +) +from cutlass.cute.runtime import from_dlpack +import cutlass.cute as cute +import torch +import cuda.bindings.driver as cuda + + +def dtype(ty: Type[Numeric]): + """ + Return the corresponding torch.dtype per the given DSL type + """ + torch_dtype = getattr(torch, ty.__name__.lower(), None) + + torch_type_map = { + Boolean: torch.bool, + # TFloat32 is just alias of float32 + TFloat32: torch.float32, + Float8E5M2: torch.float8_e5m2, + Float8E4M3FN: torch.float8_e4m3fn, + Float8E4M3B11FNUZ: torch.float8_e4m3fnuz, + } + if torch_dtype is None: + torch_dtype = torch_type_map.get(ty) + + if torch_dtype is None: + raise TypeError(f"{ty} is not supported by torch") + return torch_dtype + + +def as_tensor(pointer, shape, torch_type): + """Convert a pointer to a torch tensor""" + if torch_type.itemsize == 1: + cytype = ctypes.c_uint8 + elif torch_type.itemsize == 2: + cytype = ctypes.c_uint16 + elif torch_type.itemsize == 4: + cytype = ctypes.c_uint32 + elif torch_type.itemsize == 8: + cytype = ctypes.c_uint64 + else: + raise ValueError(f"Unsupported torch dtype: {torch_type}") + cpointer = ctypes.cast(pointer, ctypes.POINTER(cytype)) + arr = (cpointer._type_ * prod(shape)).from_address( + ctypes.addressof(cpointer.contents) + ) + 
return torch.frombuffer(arr, dtype=torch_type).view(*shape) + + +@dataclass +class ScalarInitConfig: + """Configuration for scalar initialization""" + + value: float = 0.0 + + +@dataclass +class RandomInitConfig: + """Configuration for random initialization""" + + min_val: int = -2 + max_val: int = 2 + + +@dataclass +class GaussianInitConfig: + """Configuration for Gaussian initialization""" + + mean: float = 0.0 + std: float = 1.0 + scale: float = 1.0 + + +class TensorInitType(Enum): + """Enumeration of tensor initialization types""" + + SKIP = "skip" + SCALAR = "scalar" + RANDOM = "random" + GAUSSIAN = "gaussian" + + +def create_and_permute_torch_tensor( + shape, + dtype: "torch.dtype", + permute_order=None, + init_type: TensorInitType = TensorInitType.RANDOM, + init_config: Optional[ + Union[RandomInitConfig, ScalarInitConfig, GaussianInitConfig] + ] = None, + device: Optional[torch.device] = None, +) -> "torch.Tensor": + """ + Create a torch tensor with specified shape and dtype. Optionally permute it and initialize it with specified init type and config + """ + init_dtype = torch.int32 if init_type == TensorInitType.RANDOM else torch.float32 + init_torch_tensor = torch.empty(*shape, dtype=init_dtype, device=device) + if init_type == TensorInitType.SKIP: + assert init_config is None + f32_torch_tensor = init_torch_tensor + elif init_type == TensorInitType.SCALAR: + if init_config is None: + init_config = ScalarInitConfig() + else: + if not isinstance(init_config, ScalarInitConfig): + raise ValueError("init_config must be ScalarInitConfig()") + f32_torch_tensor = init_torch_tensor.fill_(init_config.value) + elif init_type == TensorInitType.RANDOM: + if init_config is None: + init_config = RandomInitConfig() + else: + if not isinstance(init_config, RandomInitConfig): + raise ValueError("init_config must be RandomInitConfig()") + f32_torch_tensor = init_torch_tensor.random_( + init_config.min_val, init_config.max_val + ).to(dtype=torch.float32) + elif init_type == 
TensorInitType.GAUSSIAN: + if init_config is None: + init_config = GaussianInitConfig() + else: + if not isinstance(init_config, GaussianInitConfig): + raise ValueError("init_config must be GaussianInitConfig()") + f32_torch_tensor = init_torch_tensor.normal_(init_config.mean, init_config.std) + f32_torch_tensor = f32_torch_tensor * init_config.scale + else: + raise ValueError(f"Invalid init type: {init_type}") + + if permute_order is not None: + f32_torch_tensor = f32_torch_tensor.permute(permute_order) + + dtype_torch_tensor = f32_torch_tensor.to(dtype=dtype) + + return dtype_torch_tensor + + +def convert_cute_tensor( + f32_torch_tensor: "torch.Tensor", + cute_tensor: Tensor, + dtype: Type[Numeric], + is_dynamic_layout: bool = True, +) -> Tensor: + """ + Change the value of the cute tensor to make its value converted from a fp32 torch tensor. + Used for fp8 types tensor creatation now. + """ + # if torch_tensor is on cpu, create a gpu copy + if f32_torch_tensor.device.type == "cpu": + f32_torch_tensor = f32_torch_tensor.cuda() + + # Fp8 type need explicit type conversion + if dtype in { + Float8E5M2, + Float8E4M3FN, + Float8E8M0FNU, + Float4E2M1FN, + }: + fp32_cute_tensor = from_dlpack(f32_torch_tensor) + if is_dynamic_layout: + fp32_cute_tensor = fp32_cute_tensor.mark_layout_dynamic( + f32_torch_tensor.dim_order()[-1] + ) + # Copy and convert from f32 cute tensor to dtype cute tensor + cute.testing.convert(fp32_cute_tensor, cute_tensor) + return cute_tensor + + +def default_stream() -> cuda.CUstream: + """ + Get default CUstream from torch stream + """ + torch_stream = torch.cuda.default_stream() + stream = cuda.CUstream(torch_stream.cuda_stream) + return stream + + +def current_stream() -> cuda.CUstream: + """ + Get current CUstream from torch stream + """ + torch_stream = torch.cuda.current_stream() + stream = cuda.CUstream(torch_stream.cuda_stream) + return stream + + +def matrix( + l: int, + mode0: int, + mode1: int, + is_mode0_major: bool, + cutlass_dtype: 
Type[Numeric], + init_type: TensorInitType = TensorInitType.RANDOM, + init_config: Optional[ + Union[RandomInitConfig, ScalarInitConfig, GaussianInitConfig] + ] = None, + device: Optional[torch.device] = None, +) -> torch.Tensor: + """ + Create a torch tensor for matrix + + :param l: length of the matrix + :param mode0: mode0 of the matrix + :param mode1: mode1 of the matrix + :param is_mode0_major: whether the matrix is mode0 major + :param cutlass_dtype: cutlass dtype of the matrix + :param init_type: type of initialization + :param init_config: configuration for initialization + :param device: target torch device + """ + + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + + if cutlass_dtype.is_float and cutlass_dtype.width <= 8: + torch_dtype = torch.int8 + else: + torch_dtype = dtype(cutlass_dtype) + + if init_type == TensorInitType.RANDOM and init_config is None: + if torch_dtype.is_signed: + min_val = -2 + max_val = 2 + else: + min_val = 0 + max_val = 4 + init_config = RandomInitConfig(min_val=min_val, max_val=max_val) + + # Create dtype torch tensor + torch_tensor = create_and_permute_torch_tensor( + shape, + torch_dtype, + permute_order=permute_order, + init_type=init_type, + init_config=init_config, + device=device, + ) + + return torch_tensor + + +def cute_tensor_like( + data_ref: torch.Tensor, + cutlass_dtype: Type[Numeric], + is_dynamic_layout: bool, + assumed_align: Optional[int] = None, +) -> tuple[Tensor, torch.Tensor]: + """ + Create a cute tensor use a torch tensor as the data source + + :param data_ref: torch tensor as the data source + :param cutlass_dtype: cutlass dtype of the cute tensor + :param is_dynamic_layout: whether the cute tensor uses dynamic layout + :param assumed_align: assumed alignment of the cute tensor + """ + + # allocate device buffer for cute tensor + if cutlass_dtype.is_float and cutlass_dtype.width <= 8: + torch_dtype = torch.int8 + else: + 
torch_dtype = dtype(cutlass_dtype) + torch_tensor = torch.empty_like(data_ref, dtype=torch_dtype, device="cuda") + + # create cute tensor using the device buffer + cute_tensor = from_dlpack(torch_tensor, assumed_align=assumed_align) + cute_tensor.element_type = cutlass_dtype + if is_dynamic_layout: + for i, stride in enumerate(torch_tensor.stride()): + if stride == 1: + leading_dim = i + break + cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim) + + # initialize the cute tensor data + if cutlass_dtype.is_float and cutlass_dtype.width <= 8: + cute_tensor = convert_cute_tensor( + data_ref.to(dtype=torch.float32), + cute_tensor, + cutlass_dtype, + is_dynamic_layout, + ) + else: + torch_tensor.copy_(data_ref.to(dtype=torch_dtype)) + + return cute_tensor, torch_tensor diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aec0a186d7a8fc18d65637e97905c7cd5702310d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/__init__.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ +from .static_persistent_tile_scheduler import ( + WorkTileInfo, + PersistentTileSchedulerParams, + StaticPersistentTileScheduler, +) + +from .hardware_info import ( + HardwareInfo, +) + +from .blackwell_helpers import ( + compute_epilogue_tile_shape, + get_smem_store_op, + get_tmem_load_op, + get_num_tmem_alloc_cols, + make_smem_layout_a, + make_smem_layout_b, + make_smem_layout_epi, + make_trivial_tiled_mma, + make_blockscaled_trivial_tiled_mma, +) + +from .hopper_helpers import ( + sm90_get_smem_store_op, +) + +from .blockscaled_layout import ( + BlockScaledBasicChunk, + tile_atom_to_shape_SF, + make_smem_layout_sfa, + make_smem_layout_sfb, + make_tmem_layout_sfa, + make_tmem_layout_sfb, +) + +from .grouped_gemm_tile_scheduler_helper import ( + GroupSearchResult, + GroupedGemmGroupSearchState, + GroupedGemmTileSchedulerHelper, + create_initial_search_state, +) + +from .tensormap_manager import ( + TensorMapUpdateMode, + TensorMapManager, +) + +from .smem_allocator import SmemAllocator + +from .layout import LayoutEnum + +from .smem_capacity import ( + get_smem_capacity_in_bytes, +) + +from .distributed_helpers import ( + spin_lock_wait, + spin_lock_multimem_arrive, + multimem_ld_reduce_8xf16, + multimem_ld_reduce_4xf32, + multimem_ld_reduce_8xbf16, + multimem_ld_reduce_16xe4m3, + multimem_ld_reduce_16xe5m2, + multimem_st_4xb32, + sm_wise_inter_gpu_multimem_barrier, +) + +__all__ = [ + "get_smem_capacity_in_bytes", + "SmemAllocator", + "LayoutEnum", + "WorkTileInfo", + "PersistentTileSchedulerParams", + "StaticPersistentTileScheduler", + "TensorMapUpdateMode", + "TensorMapManager", + "GroupSearchResult", + "GroupedGemmGroupSearchState", + "create_initial_search_state", + "GroupedGemmTileSchedulerHelper", + "HardwareInfo", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/ampere_helpers.py 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/ampere_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..1341756f3584f89b0c201631445beb91c34dc29e --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/ampere_helpers.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from enum import Enum +from typing_extensions import deprecated +import warnings + + +@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead") +class SmemCapacity(Enum): + SM80_SMEM_CAPACITY_BYTES = (164 - 1) * 1024 + SM86_SMEM_CAPACITY_BYTES = (100 - 1) * 1024 + SM89_SMEM_CAPACITY_BYTES = (100 - 1) * 1024 + + +warnings.warn( + "SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead", + DeprecationWarning, + stacklevel=2, +) +# Dictionary to map compute capability to SMEM capacity +SMEM_CAPACITY = { + "sm80": SmemCapacity.SM80_SMEM_CAPACITY_BYTES.value, + "sm86": SmemCapacity.SM86_SMEM_CAPACITY_BYTES.value, + "sm89": SmemCapacity.SM89_SMEM_CAPACITY_BYTES.value, +} diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/blackwell_helpers.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/blackwell_helpers.py new file mode 100644 index 
0000000000000000000000000000000000000000..6fb6bf4dbfa3e73f058037e79b0999697d720502 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/blackwell_helpers.py @@ -0,0 +1,1135 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from enum import Enum +from math import log2, ceil +from typing import List, Type, Union, Tuple +from typing_extensions import deprecated +import warnings + +from cutlass.cutlass_dsl import ( + Float16, + BFloat16, + TFloat32, + Float32, + Uint8, + Int8, + Float8E4M3FN, + Float8E5M2, + Float4E2M1FN, + Numeric, + NumericMeta, + dsl_user_op, +) +import cutlass.cute as cute +from cutlass.cute.nvgpu.common import CopyUniversalOp +from cutlass.cute.nvgpu.warp import StMatrix8x8x16bOp, StMatrix16x8x8bOp +from cutlass.cute.nvgpu.tcgen05 import ( + MmaF16BF16Op, + MmaTF32Op, + MmaI8Op, + MmaFP8Op, + MmaMXF8Op, + MmaMXF4Op, + MmaMXF4NVF4Op, + OperandSource, + OperandMajorMode, + CtaGroup, + Ld16x64bOp, + Ld16x128bOp, + Ld16x256bOp, + Ld16x32bx2Op, + Ld32x32bOp, + Repetition, + Pack, + find_tmem_tensor_col_offset, + SmemLayoutAtomKind, + make_smem_layout_atom, + tile_to_mma_shape, + is_tmem_load, + get_tmem_copy_properties, +) +from cutlass.cute.nvgpu.cpasync import ( + CopyBulkTensorTileG2SMulticastOp, + CopyBulkTensorTileG2SOp, +) +from cutlass.utils.layout import LayoutEnum + + +@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead") +class SmemCapacity(Enum): + 
SM100_SMEM_CAPACITY_BYTES = (228 - 1) * 1024 + SM120_SMEM_CAPACITY_BYTES = (100 - 1) * 1024 + + +warnings.warn( + "SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead", + DeprecationWarning, + stacklevel=2, +) +# Dictionary to map compute capability to SMEM capacity +SMEM_CAPACITY = { + "sm100": SmemCapacity.SM100_SMEM_CAPACITY_BYTES.value, + "sm120": SmemCapacity.SM120_SMEM_CAPACITY_BYTES.value, +} + + +@dsl_user_op +def compute_epilogue_tile_shape( + cta_tile_shape: cute.Shape, + use_2cta_instrs: bool, + layout_d: LayoutEnum, + elem_ty_d: Type[Numeric], + *, + layout_c: LayoutEnum = None, + elem_ty_c: Union[Type[Numeric], None] = None, + loc=None, + ip=None, +) -> cute.Tile: + """Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one. + + :param cta_tile_shape: A tuple or list representing the dimensions of the CTA tile, where + cta_tile_shape[0] corresponds to the height (M) and cta_tile_shape[1] + corresponds to the width (N) of the tile. + :type cta_tile_shape: cute.Shape + :param use_2cta_instrs: A flag indicating whether the configuration is for a 2SM setup. + :type use_2cta_instrs: bool + :param layout_d: The layout enum of the output tensor D. + :type layout_d: LayoutEnum + :param elem_ty_d: The element type of output tensor D. + :type elem_ty_d: Type[Numeric] + :param layout_c: The layout enum of the input tensor C. Defaults to None. + :type layout_c: LayoutEnum, optional + :param elem_ty_c: The element type for input tensor C. Defaults to None. + :type elem_ty_c: Union[Type[Numeric], None], optional + + :return: Returns epilog tiler, which is used in subsequent epilog partitions. + :rtype: cute.Tile + + :raises ValueError: If the computed tile cute.size does not meet minimum requirements based on CTA dimensions. 
+ """ + + def validate_type(ty, ty_name): + if not isinstance(ty, NumericMeta): + raise TypeError(f"{ty_name} must be Numeric, but got {ty}") + + validate_type(elem_ty_d, "elem_ty_d") + if elem_ty_c is not None: + validate_type(elem_ty_c, "elem_ty_c") + + cta_m, cta_n = cta_tile_shape[:2] + (warp_m, warp_n) = (2, 2) if (cta_m == 64 and use_2cta_instrs) else (4, 1) + disable_source = elem_ty_c == None + max_bits = ( + elem_ty_d.width if disable_source else max(elem_ty_c.width, elem_ty_d.width) + ) + + dp_full = 32 + tile_m = min(cta_m, dp_full * warp_m) + n_perf = 0 + if disable_source: + if max_bits == 4: + compute_elts = 8192 + else: + compute_elts = 4096 + n_perf = compute_elts // tile_m + else: + if max_bits == 32: + n_perf = 16 if (cta_m > 64 and cta_n <= 128) else 32 + elif max_bits == 16: + n_perf = 32 if cta_n <= 128 else 64 + else: + n_perf = 64 + + d_is_m_major = layout_d.is_m_major_c() + c_is_m_major = True if layout_c is None else layout_c.is_m_major_c() + + n_min_d = ( + 8 * warp_n + if d_is_m_major + else (128 * warp_n if elem_ty_d.width == 6 else 128 // elem_ty_d.width * warp_n) + ) + n_min_c = ( + 8 * warp_n + if (c_is_m_major or disable_source) + else (128 * warp_n if elem_ty_c.width == 6 else 128 // elem_ty_c.width * warp_n) + ) + tile_n = min(cta_n, max(n_perf, n_min_c, n_min_d)) + + if cta_n < n_min_c or cta_n < n_min_d: + raise ValueError(f"CTA tile too small: {cta_tile_shape=}") + + # stride by tmem warp layout and return a by-mode tiler + tile_m_layout = cute.make_layout(tile_m, loc=loc, ip=ip) + tile_n_layout = cute.make_layout( + (tile_n // warp_n, warp_n), stride=(1, cta_n // warp_n), loc=loc, ip=ip + ) + return (tile_m_layout, cute.coalesce(tile_n_layout, loc=loc, ip=ip)) + + +@dsl_user_op +def get_smem_store_op( + layout_d: LayoutEnum, + elem_ty_d: Type[Numeric], + elem_ty_acc: Type[Numeric], + tiled_tmem_load: cute.TiledCopy, + *, + loc=None, + ip=None, +) -> cute.CopyAtom: + """Selects the largest vectorized smem store atom available 
subject to + constraint of gmem layout and chosen TMEM_LOAD's thread-value ownership. + + :param layout_d: The layout enum of the output tensor D. + :type layout_d: LayoutEnum + :param elem_ty_d: The element type for output tensor D. + :type elem_ty_d: Type[Numeric] + :param elem_ty_acc: The element type for accumulator. + :type elem_ty_acc: Type[Numeric] + :param tiled_tmem_load: An instance of TiledCopy that represents the tmem load operation. + :type tiled_tmem_load: cute.TiledCopy + + :return: Either SmemStoreMatrix or SimtSyncCopy, based on the input parameters. + :rtype: cute.CopyAtom + """ + + def validate_type(ty, ty_name): + if not isinstance(ty, NumericMeta): + raise TypeError(f"{ty_name} must be a Numeric, but got {ty}") + + validate_type(elem_ty_d, "elem_ty_d") + validate_type(elem_ty_acc, "elem_ty_acc") + + is_m_major = layout_d.is_m_major_c() + is_n_major = layout_d.is_n_major_c() + + if not is_tmem_load(tiled_tmem_load): + return cute.make_copy_atom(CopyUniversalOp(), elem_ty_d, loc=loc, ip=ip) + + num_dp, num_bits, num_rep, pack = get_tmem_copy_properties(tiled_tmem_load) + + use_stmatrix_m8n8_4x = ( + all( + [ + elem_ty_acc.width == 32, + elem_ty_d.width == 32, + is_n_major, + num_dp == 16, + num_bits == 128, + num_rep in (2, 4, 8, 16, 32, 64), + pack == Pack.NONE, + ] + ) + or all( + [ + elem_ty_acc.width == 32, + elem_ty_d.width == 16, + num_dp == 16, + num_bits == 256, + num_rep in (2, 4, 8, 16, 32), + pack == Pack.NONE, + ] + ) + or all( + [ + elem_ty_acc.width == 16, + elem_ty_d.width == 16, + num_dp == 16, + num_bits == 128, + num_rep in (2, 4, 8, 16, 32, 64), + pack == Pack.PACK_16b_IN_32b, + ] + ) + ) + use_stmatrix_m16n8_4x = all( + [ + elem_ty_acc.width == 32, + elem_ty_d.width == 8, + is_m_major, + num_dp == 16, + num_bits == 256, + num_rep in (4, 8, 16, 32), + pack == Pack.NONE, + ] + ) + use_stmatrix_m8n8_2x = ( + all( + [ + elem_ty_acc.width == 32, + elem_ty_d.width == 32, + is_n_major, + num_dp == 16, + num_bits == 128, + num_rep == 
1, + pack == Pack.NONE, + ] + ) + or all( + [ + elem_ty_acc.width == 32, + elem_ty_d.width == 16, + num_dp == 16, + num_bits == 256, + num_rep == 1, + pack == Pack.NONE, + ] + ) + or all( + [ + elem_ty_acc.width == 16, + elem_ty_d.width == 16, + num_dp == 16, + num_bits == 128, + num_rep == 1, + pack == Pack.PACK_16b_IN_32b, + ] + ) + ) + use_stmatrix_m16n8_2x = all( + [ + elem_ty_acc.width == 32, + elem_ty_d.width == 8, + is_m_major, + num_dp == 16, + num_bits == 256, + num_rep == 2, + pack == Pack.NONE, + ] + ) + use_stmatrix_m16n8_1x = all( + [ + elem_ty_acc.width == 32, + elem_ty_d.width == 8, + is_m_major, + num_dp == 16, + num_bits == 256, + num_rep == 1, + pack == Pack.NONE, + ] + ) + + if use_stmatrix_m8n8_4x: + op = StMatrix8x8x16bOp(is_m_major, 4) + return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) + elif use_stmatrix_m8n8_2x: + op = StMatrix8x8x16bOp(is_m_major, 2) + return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) + elif use_stmatrix_m16n8_4x: + op = StMatrix16x8x8bOp(4) + return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) + elif use_stmatrix_m16n8_2x: + op = StMatrix16x8x8bOp(2) + return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) + elif use_stmatrix_m16n8_1x: + op = StMatrix16x8x8bOp(1) + return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) + else: + op = CopyUniversalOp() + return cute.make_copy_atom(op, elem_ty_d, loc=loc, ip=ip) + + +@dsl_user_op +def get_tmem_load_op( + cta_tile_shape: cute.Shape, + layout_d: LayoutEnum, + elem_ty_d: Type[Numeric], + elem_ty_acc: Type[Numeric], + epi_tile: cute.Tile, + use_2cta_instrs: bool, + *, + loc=None, + ip=None, +) -> cute.CopyAtom: + """Finds a performant TMEM_LOAD copy op for the selected epilogue + tile (epi_tile), element types, and tcgen05.mma instruction used. + + :param cta_tile_shape: A tuple or list representing the dimensions of the CTA tile. + :type cta_tile_shape: cute.Shape + :param layout_d: The layout enum of the output tensor D. 
+ :type layout_d: LayoutEnum + :param elem_ty_d: The element type for output tensor D. + :type elem_ty_d: Type[Numeric] + :param elem_ty_acc: The element type for accumulation. + :type elem_ty_acc: Type[Numeric] + :param epi_tile: The epilogue tile configuration. + :type epi_tile: cute.Tile + :param use_2cta_instrs: A flag indicating whether the configuration is for 2 SMs. + :type use_2cta_instrs: bool + + :return: An instance of Sm100TmemLoad with the computed configuration. + :rtype: cute.CopyAtom + + :raises ValueError: If the function cannot handle the given combination of accumulation + and dimension types, or if it cannot determine the appropriate configuration based on + the input parameters. + """ + is_m_major = layout_d.is_m_major_c() + + acc_bits = elem_ty_acc.width + d_bits = elem_ty_d.width + + tmem_warp_shape_mn = ( + (2, 2) if (cta_tile_shape[0] == 64 and use_2cta_instrs) else (4, 1) + ) + epilog_tile_shape_mn = cute.product_each( + cute.shape(epi_tile, loc=loc, ip=ip), loc=loc, ip=ip + ) + epilog_warp_tile_shape_mn = cute.shape_div( + epilog_tile_shape_mn, tmem_warp_shape_mn, loc=loc, ip=ip + ) + + num_dp = cute.size(epilog_warp_tile_shape_mn[0], loc=loc, ip=ip) + if num_dp not in {16, 32}: + raise ValueError("Cta tile and 2sm config does not generate correct num dp.") + + num_col_bits = cute.size(epilog_warp_tile_shape_mn[1], loc=loc, ip=ip) * acc_bits + + tmem_dp = 0 + tmem_bit = 0 + tmem_rep = 0 + tmem_pack16b = False + if acc_bits == 32 and d_bits == 32: + if num_dp == 16: + if is_m_major: + tmem_dp = 16 + tmem_bit = 256 + else: + tmem_dp = 16 + tmem_bit = 128 + else: + tmem_dp = 32 + tmem_bit = 32 + elif acc_bits == 32 and d_bits == 16: + if num_dp == 16: + if is_m_major: + tmem_dp = 16 + tmem_bit = 256 + else: + tmem_dp = 16 + tmem_bit = 256 + else: + if is_m_major: + tmem_dp = 16 + tmem_bit = 256 + else: + tmem_dp = 32 + tmem_bit = 32 + elif acc_bits == 32 and d_bits == 8: + if num_dp == 16: + if is_m_major: + tmem_dp = 16 + tmem_bit = 256 + 
else: + tmem_dp = 16 + tmem_bit = 32 + else: + if is_m_major: + tmem_dp = 16 + tmem_bit = 256 + else: + tmem_dp = 32 + tmem_bit = 32 + elif acc_bits == 16 and d_bits == 16: + tmem_pack16b = True + if num_dp == 16: + if is_m_major: + tmem_dp = 16 + tmem_bit = 128 + else: + tmem_dp = 16 + tmem_bit = 128 + else: + if is_m_major: + tmem_dp = 16 + tmem_bit = 128 + else: + tmem_dp = 32 + tmem_bit = 32 + elif acc_bits == 32 and d_bits == 6: + if not num_dp == 32: + raise ValueError("Num dp must be 32.") + tmem_dp = 32 + tmem_bit = 32 + elif acc_bits == 32 and d_bits == 4: + if not num_dp == 32: + raise ValueError("Num dp must be 32.") + tmem_dp = 32 + tmem_bit = 32 + else: + raise ValueError( + f"Can not handle acc/d type combination: {elem_ty_acc=}, {elem_ty_d=}" + ) + + num_bit_div = tmem_bit + if tmem_dp == 16 and tmem_bit == 32: + num_bit_div = 64 + + if (num_col_bits % (num_bit_div * 128) == 0) and ( + (tmem_dp == 16 and tmem_bit == 64) + or (tmem_dp == 16 and tmem_bit == 32) + or (tmem_dp == 32 and tmem_bit == 32) + ): + tmem_rep = 128 + elif (num_col_bits % (num_bit_div * 64) == 0) and ( + (tmem_dp == 16 and tmem_bit == 128) + or (tmem_dp == 16 and tmem_bit == 64) + or (tmem_dp == 16 and tmem_bit == 32) + or (tmem_dp == 32 and tmem_bit == 32) + ): + tmem_rep = 64 + elif num_col_bits % (num_bit_div * 32) == 0: + tmem_rep = 32 + elif num_col_bits % (num_bit_div * 16) == 0: + tmem_rep = 16 + elif num_col_bits % (num_bit_div * 8) == 0: + tmem_rep = 8 + elif num_col_bits % (num_bit_div * 4) == 0: + tmem_rep = 4 + elif num_col_bits % (num_bit_div * 2) == 0: + tmem_rep = 2 + elif num_col_bits % (num_bit_div * 1) == 0: + tmem_rep = 1 + else: + raise ValueError("Can not pick tmem_rep based on cta tile shape and tmem atom.") + + if tmem_dp == 16 and tmem_bit == 64: + op = Ld16x64bOp( + Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE + ) + return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip) + elif tmem_dp == 16 and tmem_bit == 128: + op = 
Ld16x128bOp( + Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE + ) + return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip) + elif tmem_dp == 16 and tmem_bit == 256: + op = Ld16x256bOp( + Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE + ) + return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip) + elif tmem_dp == 16 and tmem_bit == 32: + op = Ld16x32bx2Op( + Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE + ) + return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip) + + elif tmem_dp == 32 and tmem_bit == 32: + op = Ld32x32bOp( + Repetition(tmem_rep), Pack.PACK_16b_IN_32b if tmem_pack16b else Pack.NONE + ) + return cute.make_copy_atom(op, elem_ty_acc, loc=loc, ip=ip) + else: + raise ValueError() + + +def get_num_tmem_alloc_cols( + tmem_tensors: Union[cute.Tensor, List[cute.Tensor]], rounding=True +) -> int: + """Get the total number of TMEM allocation columns for the given TMEM tensors. + + :param tmem_tensors: The TMEM tensors to get the number of allocation columns for. + :type tmem_tensors: Union[cute.Tensor, List[cute.Tensor]] + :param rounding: Whether to round up the number of allocation columns to the nearest power of 2. + :type rounding: bool + + :return: The total number of TMEM allocation columns. + :rtype: int + + :raises ValueError: If the number of TMEM allocation columns exceeds the maximum capacity of 512 or is less than 32. 
+ """ + # Turn tmem_tensors into a list + if isinstance(tmem_tensors, cute.Tensor): + tmem_tensors = [tmem_tensors] + + # For each tensor in tmem_tensors, find the tmem_tensor_col_offset + num_tmem_alloc_cols_per_tensor = [ + find_tmem_tensor_col_offset(t) for t in tmem_tensors + ] + + # Sum up the num_tmem_alloc_cols_per_tensor + num_tmem_alloc_cols = sum(num_tmem_alloc_cols_per_tensor) + + # Round up num_tmem_cols_total to the nearest power of 2 + if rounding: + num_tmem_alloc_cols = 1 << ceil(log2(num_tmem_alloc_cols)) + + # Validate the number of TMEM allocation columns + SM100_TMEM_CAPACITY_COLUMNS = 512 + SM100_TMEM_MIN_ALLOC_COLUMNS = 32 + if ( + num_tmem_alloc_cols > SM100_TMEM_CAPACITY_COLUMNS + or num_tmem_alloc_cols < SM100_TMEM_MIN_ALLOC_COLUMNS + ): + raise ValueError( + f"TMEM allocation columns {num_tmem_alloc_cols} exceeds the maximum capacity of {SM100_TMEM_CAPACITY_COLUMNS} or less than {SM100_TMEM_MIN_ALLOC_COLUMNS}" + ) + return num_tmem_alloc_cols + + +def get_smem_layout_atom_ab( + major_mode: OperandMajorMode, + element_type: Type[Numeric], + smem_shape_mn_k: Tuple[int, int], + *, + loc=None, + ip=None, +) -> SmemLayoutAtomKind: + """Simple heuristics to select the optimal SMEM layout atom based on the + majorness, the data type, and the major mode size. + + :param major_mode: The major mode for the SMEM tensor is K major. + :type major_mode: OperandMajorMode + :param element_type: The element type for the SMEM tensor. + :type element_type: Type[Numeric] + :param smem_shape_mn_k: The shape of the SMEM tensor. 
+ :type smem_shape_mn_k: Tuple[int, int] + + :return: The SMEM layout atom kind + :rtype: SmemLayoutAtomKind + """ + is_k_major = major_mode == OperandMajorMode.K + major_mode_size = smem_shape_mn_k[1] if is_k_major else smem_shape_mn_k[0] + + assert major_mode_size % 8 == 0 + sw128_num_contiguous_bits = 1024 + sw64_num_contiguous_bits = 512 + sw32_num_contiguous_bits = 256 + inter_num_contiguous_bits = 128 + major_mode_size_bits = major_mode_size * element_type.width + assert major_mode_size_bits % inter_num_contiguous_bits == 0 + + if not is_k_major: + if (element_type.width == 32) and ( + major_mode_size_bits % sw128_num_contiguous_bits == 0 + ): + return SmemLayoutAtomKind.MN_SW128_32B + if major_mode_size_bits % sw128_num_contiguous_bits == 0: + return SmemLayoutAtomKind.MN_SW128 + if major_mode_size_bits % sw64_num_contiguous_bits == 0: + return SmemLayoutAtomKind.MN_SW64 + if major_mode_size_bits % sw32_num_contiguous_bits == 0: + return SmemLayoutAtomKind.MN_SW32 + return SmemLayoutAtomKind.MN_INTER + if major_mode_size_bits % sw128_num_contiguous_bits == 0: + return SmemLayoutAtomKind.K_SW128 + if major_mode_size_bits % sw64_num_contiguous_bits == 0: + return SmemLayoutAtomKind.K_SW64 + if major_mode_size_bits % sw32_num_contiguous_bits == 0: + return SmemLayoutAtomKind.K_SW32 + return SmemLayoutAtomKind.K_INTER + + +@dsl_user_op +def make_smem_layout_a( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + a_dtype: Type[Numeric], + num_stages: int, + *, + loc=None, + ip=None, +) -> Union[cute.Layout, cute.ComposedLayout]: + """This function helps with: + 1. Get the partitioned shape of the A tensor based on the tiled_mma & MMA tiler. + 2. Select the heuristic SMEM layout atom based on the A tensor's majorness, the data type, and the major mode size. + 3. cute.Tile the SMEM layout atom to the MMA tile shape. + 4. Stage the SMEM layout based on the number of stages. 
+ + :param tiled_mma: The tiled MMA used to partition tensor A + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The MMA tile shape + :type mma_tiler_mnk: cute.cute.Tile + :param a_dtype: The element type for tensor A + :type a_dtype: Type[Numeric] + :param num_stages: The number of pipeline stages for tensor A + :type num_stages: int + + :return: SMEM layout for tensor A + :rtype: Union[cute.Layout, cute.ComposedLayout] + """ + + is_k_major = tiled_mma.op.a_major_mode == OperandMajorMode.K + a_smem_shape = tiled_mma.partition_shape_A( + cute.dice(mma_tiler_mnk, (1, None, 1), loc=loc, ip=ip) + ) + a_smem_shape_mn_k = ( + cute.size(a_smem_shape[0][0], loc=loc, ip=ip) * a_smem_shape[1], + cute.size(a_smem_shape[0][1], loc=loc, ip=ip) * a_smem_shape[2], + ) + a_smem_layout_atom = make_smem_layout_atom( + get_smem_layout_atom_ab( + tiled_mma.op.a_major_mode, + a_dtype, + a_smem_shape_mn_k, + loc=loc, + ip=ip, + ), + a_dtype, + loc=loc, + ip=ip, + ) + a_smem_layout_staged = tile_to_mma_shape( + a_smem_layout_atom, + cute.append(a_smem_shape, num_stages, loc=loc, ip=ip), + order=((1, 0, 2) if not is_k_major else (0, 1, 2)), + loc=loc, + ip=ip, + ) + return a_smem_layout_staged + + +@dsl_user_op +def make_smem_layout_b( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + b_dtype: Type[Numeric], + num_stages: int, + *, + loc=None, + ip=None, +) -> Union[cute.Layout, cute.ComposedLayout]: + """This function helps: + 1. Get the partitioned shape of the B tensor based on the tiled_mma & MMA tiler. + 2. Select the heuristic SMEM layout atom based on the B tensor's majorness, the data type, and the major mode size. + 3. cute.Tile the SMEM layout atom to the MMA tile shape. + 4. Stage the SMEM layout based on the number of stages. + + :param tiled_mma: The tiled MMA which is used to partition the B tensor. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The MMA tile shape. 
+ :type mma_tiler_mnk: cute.cute.Tile + :param b_dtype: The element type for the B tensor. + :type b_dtype: Type[Numeric] + :param num_stages: The stage of the B tensor. + :type num_stages: int + + :return: SMEM layout for the B tensor. + :rtype: Union[cute.Layout, cute.ComposedLayout] + """ + + is_k_major = tiled_mma.op.b_major_mode == OperandMajorMode.K + b_smem_shape = tiled_mma.partition_shape_B( + cute.dice(mma_tiler_mnk, (None, 1, 1), loc=loc, ip=ip) + ) + b_smem_shape_nk = ( + cute.size(b_smem_shape[0][0], loc=loc, ip=ip) * b_smem_shape[1], + cute.size(b_smem_shape[0][1], loc=loc, ip=ip) * b_smem_shape[2], + ) + b_smem_layout_atom = make_smem_layout_atom( + get_smem_layout_atom_ab( + tiled_mma.op.b_major_mode, + b_dtype, + b_smem_shape_nk, + loc=loc, + ip=ip, + ), + b_dtype, + loc=loc, + ip=ip, + ) + b_smem_layout_staged = tile_to_mma_shape( + b_smem_layout_atom, + cute.append(b_smem_shape, num_stages, loc=loc, ip=ip), + order=((1, 0, 2) if not is_k_major else (0, 1, 2)), + loc=loc, + ip=ip, + ) + + return b_smem_layout_staged + + +@dsl_user_op +def get_smem_layout_atom_epi( + layout: LayoutEnum, + element_type: Type[Numeric], + epi_tile: cute.Tile, + *, + loc=None, + ip=None, +) -> SmemLayoutAtomKind: + """Simple heuristics to select the optimal SMEM layout atom for epilog tensors. + + :param layout: The layout enum for the SMEM tensor. + :type layout: LayoutEnum + :param element_type: The element type for the SMEM tensor. + :type element_type: Type[Numeric] + :param epi_tile: The epilogue tile shape. 
+ :type epi_tile: cute.Tile + + :return: The SMEM layout atom kind + :rtype: SmemLayoutAtomKind + """ + # Get the max contiguous tile usable by TMA + tma_shape = tuple( + ( + # assumes get<0>(epi_tile) is coalesced and unit stride + cute.coalesce(cute.right_inverse(x, loc=loc, ip=ip), loc=loc, ip=ip).shape + if isinstance(x, cute.Layout) + else x + ) + for x in epi_tile + ) + + if layout.is_m_major_c(): + # ColMajor C/D (M-major) + return get_smem_layout_atom_ab( + OperandMajorMode.MN, element_type, tma_shape, loc=loc, ip=ip + ) + else: + # RowMajor C/D (N-major) + return get_smem_layout_atom_ab( + OperandMajorMode.K, element_type, tma_shape, loc=loc, ip=ip + ) + + +@dsl_user_op +def make_smem_layout_epi( + epi_dtype: Type[Numeric], + epi_layout: LayoutEnum, + epi_tile: cute.Tile, + epi_stage: int, + *, + loc=None, + ip=None, +) -> Union[cute.Layout, cute.ComposedLayout]: + """This function helps: + 1. Select the heuristic SMEM layout atom based on the epilog tile shape, + the epilog tensor's majorness, and the element type. + 2. cute.Tile the SMEM layout atom to the epilog tile shape. + 3. Stage the SMEM layout based on the number of stages. + + :param epi_dtype: The element type for the epilog tensor. + :type epi_dtype: Type[Numeric] + :param epi_layout: The layout enum for the epilog tensor. + :type epi_layout: LayoutEnum + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.cute.Tile + :param epi_stage: The stage of the epilog tensor. 
+ :type epi_stage: int + + :return: SMEM layout for epilog tensors (usually C & D which are processed in the epilog) + :rtype: Union[cute.Layout, cute.ComposedLayout] + """ + + epilog_shape = cute.product_each( + cute.shape(epi_tile, loc=loc, ip=ip), loc=loc, ip=ip + ) + + c_smem_layout_atom = make_smem_layout_atom( + get_smem_layout_atom_epi( + epi_layout, + epi_dtype, + epi_tile, + loc=loc, + ip=ip, + ), + epi_dtype, + loc=loc, + ip=ip, + ) + epi_smem_layout_staged = cute.tile_to_shape( + c_smem_layout_atom, + cute.append(epilog_shape, epi_stage, loc=loc, ip=ip), + order=((1, 0, 2) if not epi_layout.is_n_major_c() else (0, 1, 2)), + loc=loc, + ip=ip, + ) + + return epi_smem_layout_staged + + +@dsl_user_op +def make_trivial_tiled_mma( + ab_dtype: Type[Numeric], + a_leading_mode: OperandMajorMode, + b_leading_mode: OperandMajorMode, + acc_dtype: Type[Numeric], + cta_group: CtaGroup, + mma_tiler_mn: Tuple[int, int], + a_source: OperandSource = OperandSource.SMEM, + *, + loc=None, + ip=None, +) -> cute.TiledMma: + """Make a tiled MMA atom with given data type, leading dimension, cta group and mma tile shape. + By default, the MMA atom is created with SMEM operand source for A. + + :param ab_dtype: Data type of operands A and B. + :type ab_dtype: type[Numeric] + :param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N). + :type a_leading_mode: tcgen05.OperandMajorMode + :param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N). + :type b_leading_mode: tcgen05.OperandMajorMode + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[Numeric] + :param cta_group: The CTA group to use. + :type cta_group: tcgen05.CtaGroup + :param mma_tiler_mn: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mn: Tuple[int, int] + :param a_source: The source of operand A (SMEM by default or TMEM). + :type a_source: OperandSource + + :return: A tiled MMA atom. 
+ :rtype: cute.TiledMma + + :raises TypeError: If the data type is not supported. + """ + + if ab_dtype in {Float16, BFloat16}: + mma_op = MmaF16BF16Op( + ab_dtype, + acc_dtype, + (*mma_tiler_mn, 16), + cta_group, + a_source, + a_leading_mode, + b_leading_mode, + ) + elif ab_dtype in {TFloat32, Float32}: + mma_op = MmaTF32Op( + (*mma_tiler_mn, 8), + cta_group, + a_source, + a_leading_mode, + b_leading_mode, + ) + elif ab_dtype in { + Uint8, + Int8, + }: + mma_op = MmaI8Op( + ab_dtype, + (*mma_tiler_mn, 32), + cta_group, + a_source, + a_leading_mode, + b_leading_mode, + ) + elif ab_dtype in {Float8E4M3FN, Float8E5M2}: + mma_op = MmaFP8Op( + ab_dtype, + acc_dtype, + (*mma_tiler_mn, 32), + cta_group, + a_source, + a_leading_mode, + b_leading_mode, + ) + else: + raise TypeError(f"unsupported ab_dtype, got {ab_dtype}") + + return cute.make_tiled_mma(cute.make_mma_atom(mma_op)) + + +@dsl_user_op +def make_blockscaled_trivial_tiled_mma( + ab_dtype: Type[Numeric], + a_leading_mode: OperandMajorMode, + b_leading_mode: OperandMajorMode, + sf_dtype: Type[Numeric], + sf_vec_size: int, + cta_group: CtaGroup, + mma_tiler_mn: Tuple[int, int], + a_source: OperandSource = OperandSource.SMEM, + *, + loc=None, + ip=None, +) -> cute.TiledMma: + """Make a BlockScaled tiled MMA atom with given data type, leading dimension, cta group and mma tile shape. + By default, the MMA atom is created with SMEM operand source for A. + + :param ab_dtype: Data type of operands A and B. + :type ab_dtype: type[Numeric] + :param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N). + :type a_leading_mode: tcgen05.OperandMajorMode + :param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N). + :type b_leading_mode: tcgen05.OperandMajorMode + :param sf_dtype: Data type of the Scale Factor. + :type sf_dtype: type[Numeric] + :param sf_vec_size: The vector size of the Scale Factor. + :type sf_vec_size: int + :param cta_group: The CTA group to use. 
+ :type cta_group: tcgen05.CtaGroup + :param mma_tiler_mn: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mn: Tuple[int, int] + :param a_source: The source of operand A (SMEM by default or TMEM). + :type a_source: OperandSource + + :return: A tiled MMA atom. + :rtype: cute.TiledMma + + :raises TypeError: If the data type is not supported. + """ + if ab_dtype in {Float8E4M3FN, Float8E5M2}: + mma_op = MmaMXF8Op( + ab_dtype, + (*mma_tiler_mn, 32), + cta_group, + a_source, + a_leading_mode, + b_leading_mode, + ) + elif ab_dtype == Float4E2M1FN: + if sf_vec_size == 32: + mma_op = MmaMXF4Op( + (*mma_tiler_mn, 64), + cta_group, + a_source, + ) + elif sf_vec_size == 16: + mma_op = MmaMXF4NVF4Op( + sf_dtype, + (*mma_tiler_mn, 64), + cta_group, + a_source, + ) + else: + raise ValueError(f"unsupported sf_vec_size, got {sf_vec_size}") + else: + raise TypeError(f"unsupported ab_dtype, got {ab_dtype}") + + return cute.make_tiled_mma(cute.make_mma_atom(mma_op)) + + +@dsl_user_op +def cluster_shape_to_tma_atom_A( + cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None +) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]: + """ + Select the appropriate TMA copy atom for A based on the number of SMs and the multicast flag. 
+ + :param cluster_shape_mnk: The shape of the cluster + :type cluster_shape_mnk: cute.Shape + :param atom_thr_id: The thread ID of the atom + :type atom_thr_id: cute.Layout + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + :raise ValueError: If the cluster shape is not divisible by the atom SM count + """ + atom_sm_cnt = cute.size(atom_thr_id, loc=loc, ip=ip) + mcast = not (cute.size(cluster_shape_mnk, mode=[1], loc=loc, ip=ip) == 1) + cluster_size = cute.size(cluster_shape_mnk, loc=loc, ip=ip) + + if not isinstance(cluster_size, int) or not isinstance(atom_sm_cnt, int): + raise ValueError( + f"Dynamic cluster shape or atom SM count is not supported: {cluster_shape_mnk} and {atom_thr_id}" + ) + + if cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) % atom_sm_cnt != 0: + raise ValueError( + f"Cluster shape not divisible by MMA size: {cluster_shape_mnk} and {atom_thr_id}" + ) + + if atom_sm_cnt == 2 and mcast: + return CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO) + elif atom_sm_cnt == 2 and not mcast: + return CopyBulkTensorTileG2SOp(CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return CopyBulkTensorTileG2SMulticastOp(CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return CopyBulkTensorTileG2SOp(CtaGroup.ONE) + + raise ValueError( + f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}" + ) + + +@dsl_user_op +def cluster_shape_to_tma_atom_B( + cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None +) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]: + """ + Select the appropriate TMA copy atom for Bbased on the number of SMs and the multicast flag. 
+ + :param cluster_shape_mnk: The shape of the cluster + :type cluster_shape_mnk: cute.Shape + :param atom_thr_id: The thread ID of the atom + :type atom_thr_id: cute.Layout + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + :raise ValueError: If the cluster shape is not divisible by the atom SM count + """ + atom_sm_cnt = cute.size(atom_thr_id, loc=loc, ip=ip) + mcast = not (cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) == atom_sm_cnt) + cluster_size = cute.size(cluster_shape_mnk, loc=loc, ip=ip) + + if not isinstance(cluster_size, int) or not isinstance(atom_sm_cnt, int): + raise ValueError( + f"Dynamic cluster shape or atom SM count is not supported: {cluster_shape_mnk} and {atom_thr_id}" + ) + + if cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) % atom_sm_cnt != 0: + raise ValueError( + f"Cluster shape not divisible by MMA size: {cluster_shape_mnk} and {atom_thr_id}" + ) + + if atom_sm_cnt == 2 and mcast: + return CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO) + elif atom_sm_cnt == 2 and not mcast: + return CopyBulkTensorTileG2SOp(CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return CopyBulkTensorTileG2SMulticastOp(CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return CopyBulkTensorTileG2SOp(CtaGroup.ONE) + + raise ValueError( + f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}" + ) + + +@dsl_user_op +def cluster_shape_to_tma_atom_SFB( + cluster_shape_mnk: cute.Shape, atom_thr_id: cute.Layout, *, loc=None, ip=None +) -> Union[CopyBulkTensorTileG2SMulticastOp, CopyBulkTensorTileG2SOp]: + """ + Select the appropriate TMA copy atom for SFB based on the number of SMs and the multicast flag. 
+ + :param cluster_shape_mnk: The shape of the cluster + :type cluster_shape_mnk: cute.Shape + :param atom_thr_id: The thread ID of the atom + :type atom_thr_id: cute.Layout + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + :raise ValueError: If the cluster shape is not divisible by the atom SM count + """ + atom_sm_cnt = cute.size(atom_thr_id, loc=loc, ip=ip) + mcast = not (cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) == 1) + cluster_size = cute.size(cluster_shape_mnk, loc=loc, ip=ip) + + if not isinstance(cluster_size, int) or not isinstance(atom_sm_cnt, int): + raise ValueError( + f"Dynamic cluster shape or atom SM count is not supported: {cluster_shape_mnk} and {atom_thr_id}" + ) + + if cute.size(cluster_shape_mnk, mode=[0], loc=loc, ip=ip) % atom_sm_cnt != 0: + raise ValueError( + f"Cluster shape not divisible by MMA size: {cluster_shape_mnk} and {atom_thr_id}" + ) + + if atom_sm_cnt == 2: + return CopyBulkTensorTileG2SMulticastOp(CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return CopyBulkTensorTileG2SMulticastOp(CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return CopyBulkTensorTileG2SOp(CtaGroup.ONE) + + raise ValueError( + f"Unsupported Configuration for SM100 TMA: {cluster_shape_mnk} and {atom_thr_id}" + ) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/blockscaled_layout.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/blockscaled_layout.py new file mode 100644 index 0000000000000000000000000000000000000000..fa1e2eb70e38236d73f435e001fdc160d301c47c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/blockscaled_layout.py @@ -0,0 +1,287 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
# All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

from dataclasses import dataclass, field
from typing import Union

from cutlass.cutlass_dsl import dsl_user_op

import cutlass.cute as cute
from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode

import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir


@dataclass(frozen=True)
class BlockScaledBasicChunk:
    """
    The basic scale factor atom layout decided by tcgen05 BlockScaled MMA Ops.

    This class represents the fixed layout pattern for scale factors used in
    tcgen05 BlockScaled MMA Ops. The layout is determined by the
    instruction specification and cannot be modified.
    See the PTX documentation (the hyperlink target is missing in this copy).

    The layout is built once in ``__post_init__`` from ``sf_vec_size`` and
    ``major_mode`` and exposed read-only through :attr:`layout`.
    """

    # Scale factor vector size; used as the extent of the stride-0 mode below.
    sf_vec_size: int
    # Selects which of the two atom orderings is built; defaults to K-major.
    major_mode: OperandMajorMode = OperandMajorMode.K
    # Derived layout; excluded from __init__ and repr, filled in __post_init__.
    _layout: cute.Layout = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Build the atom layout for the selected major mode and cache it in ``_layout``."""
        if self.major_mode == OperandMajorMode.K:
            # K-major layout: (AtomMN, AtomK)
            atom_shape = ((32, 4), (self.sf_vec_size, 4))
            atom_stride = ((16, 4), (0, 1))
        else:
            # MN-major layout: (AtomK, AtomMN)
            # Same two modes as the K-major case, in swapped order.
            atom_shape = ((self.sf_vec_size, 4), (32, 4))
            atom_stride = ((0, 1), (16, 4))

        # The dataclass is frozen, so normal attribute assignment would raise;
        # object.__setattr__ bypasses the frozen check for this derived field.
        object.__setattr__(
            self, "_layout", cute.make_layout(atom_shape, stride=atom_stride)
        )

    @property
    def layout(self) -> cute.Layout:
        """
        Get the layout for this block scaled chunk.

        :return: The layout representing the scale factor atom
        :rtype: cute.Layout
        """
        return self._layout


@dsl_user_op
def tile_atom_to_shape_SF(
    Shape: cute.Shape,
    sf_vec_size: int,
    *,
    loc=None,
    ip=None,
) -> cute.Layout:
    """
    A helper function to get dynamic SFA/SFB layout by filling dynamic A/B shape to the scale factor atom layout.

    The atom is always the K-major one, because ``BlockScaledBasicChunk`` is
    constructed with its default ``major_mode``.

    :param Shape: The shape of the A/B tensor
    :type Shape: cute.Shape
    :param sf_vec_size: Scale factor vector size
    :type sf_vec_size: int
    :param loc: Source location for the DSL op
    :param ip: Insertion point for the DSL op

    :return: The layout of the SFA/SFB tensor
    :rtype: cute.Layout
    """
    # NOTE(review): loc and ip are accepted but not forwarded to
    # cute.tile_to_shape below -- confirm whether that is intended.
    # ((Atom_MN, Rest_MN),(Atom_K, Rest_K),RestL)
    sf_layout = cute.tile_to_shape(
        BlockScaledBasicChunk(sf_vec_size).layout, Shape, (2, 1, 3)
    )
    return sf_layout


@dsl_user_op
def make_smem_layout_sfa(
    tiled_mma: cute.TiledMma,
    mma_tiler_mnk: cute.Tile,
    sf_vec_size: int,
    num_stages: int,
    *,
    loc=None,
    ip=None,
) -> cute.Layout:
    """
    Make smem layout for SFA based on:
    1. BlockScaledBasicChunk
    2. MMA tiler shape
    3. Scale factor vector size
    4. Number of stages

    :param tiled_mma: The tiled MMA
    :type tiled_mma: cute.TiledMma
    :param mma_tiler_mnk: The mma tiler shape
    :type mma_tiler_mnk: cute.Tile
    :param sf_vec_size: The scale factor vector size
    :type sf_vec_size: int
    :param num_stages: The number of stages
    :type num_stages: int
    :param loc: Source location for the DSL op
    :param ip: Insertion point for the DSL op

    :return: Smem layout for SFA
    :rtype: cute.Layout
    """
    # NOTE(review): loc and ip are accepted but not forwarded to any of the
    # cute calls in this function -- confirm whether that is intended.
    # (CTA_Tile_Shape_M, MMA_Tile_Shape_K)
    # The M extent of the MMA tiler is divided by the size of the tiled MMA's
    # thr_id shape to get the per-CTA M extent.
    sfa_tile_shape = (
        mma_tiler_mnk[0] // cute.size(tiled_mma.thr_id.shape),
        mma_tiler_mnk[2],
    )

    # ((Atom_M, Rest_M),(Atom_K, Rest_K))
    smem_layout = cute.tile_to_shape(
        BlockScaledBasicChunk(sf_vec_size).layout,
        sfa_tile_shape,
        (2, 1),
    )

    # Hard-coded divisor applied to the K mode of the tile shape below.
    mma_tile_inst_k = 4
    # (CTA_Tile_Shape_M, MMA_Inst_Shape_K)
    sfa_tile_shape = cute.shape_div(sfa_tile_shape, (1, mma_tile_inst_k))
    # ((Atom_Inst_M, Atom_Inst_K), MMA_M, MMA_K))
    smem_layout = cute.tiled_divide(smem_layout, sfa_tile_shape)

    # Hard-coded M extent of the instruction-level tiler.
    atom_m = 128
    tiler_inst = ((atom_m, sf_vec_size),)
    # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K)
    smem_layout = cute.logical_divide(smem_layout, tiler_inst)

    # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE)
    # The stage mode's stride is the cosize of the single-stage layout after
    # cute.filter_zeros has been applied to it.
    sfa_smem_layout_staged = cute.append(
        smem_layout,
        cute.make_layout(
            num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout))
        ),
    )

    return sfa_smem_layout_staged


@dsl_user_op
def make_smem_layout_sfb(
    tiled_mma: cute.TiledMma,
    mma_tiler_mnk: cute.Tile,
    sf_vec_size: int,
    num_stages: int,
    *,
    loc=None,
    ip=None,
) -> cute.Layout:
    """
    Make smem layout for SFB based on:
    1. BlockScaledBasicChunk
    2. MMA tiler shape
    3. Scale factor vector size
    4. Number of stages

    :param tiled_mma: The tiled MMA (not read by this function)
    :type tiled_mma: cute.TiledMma
    :param mma_tiler_mnk: The mma tiler shape
    :type mma_tiler_mnk: cute.Tile
    :param sf_vec_size: The scale factor vector size
    :type sf_vec_size: int
    :param num_stages: The number of stages
    :type num_stages: int
    :param loc: Source location for the DSL op
    :param ip: Insertion point for the DSL op

    :return: Smem layout for SFB
    :rtype: cute.Layout
    """
    # NOTE(review): loc and ip are accepted but not forwarded to any of the
    # cute calls in this function -- confirm whether that is intended.
    # (Round_Up(CTA_Tile_Shape_N, 128), MMA_Tile_Shape_K)
    # Unlike the SFA variant, the N extent is rounded up to a multiple of 128
    # and is not divided by the tiled MMA's thread count.
    sfb_tile_shape = (
        cute.round_up(mma_tiler_mnk[1], 128),
        mma_tiler_mnk[2],
    )

    # ((Atom_N, Rest_N),(Atom_K, Rest_K))
    smem_layout = cute.tile_to_shape(
        BlockScaledBasicChunk(sf_vec_size).layout,
        sfb_tile_shape,
        (2, 1),
    )

    # Hard-coded divisor applied to the K mode of the tile shape below.
    mma_tile_inst_k = 4
    # (CTA_Tile_Shape_N, MMA_Inst_Shape_K)
    sfb_tile_shape = cute.shape_div(sfb_tile_shape, (1, mma_tile_inst_k))
    # ((Atom_Inst_N, Atom_Inst_K), MMA_N, MMA_K)
    smem_layout = cute.tiled_divide(smem_layout, sfb_tile_shape)

    # Hard-coded N extent of the instruction-level tiler.
    atom_n = 128
    tiler_inst = ((atom_n, sf_vec_size),)
    # (((Atom_Inst_N, Rest_N),(Atom_Inst_K, Rest_K)), MMA_N, MMA_K)
    smem_layout = cute.logical_divide(smem_layout, tiler_inst)

    # (((Atom_Inst_N, Rest_N),(Atom_Inst_K, Rest_K)), MMA_N, MMA_K, STAGE)
    # The stage mode's stride is the cosize of the single-stage layout after
    # cute.filter_zeros has been applied to it.
    sfb_smem_layout_staged = cute.append(
        smem_layout,
        cute.make_layout(
            num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout))
        ),
    )

    return sfb_smem_layout_staged


@dsl_user_op
def make_tmem_layout_sfa(
    tiled_mma: cute.TiledMma,
    mma_tiler_mnk: cute.Tile,
    sf_vec_size: int,
    smem_layout: cute.Layout,
    *,
    loc=None,
    ip=None,
) -> cute.Layout:
    """Make tmem layout for SFA based on:
    1. SFA smem layout per stage
    2. Cta tile shape m
    3. tiled MMA atom thr size
    4. Scale factor vector size

    :param tiled_mma: The tiled MMA
    :type tiled_mma: cute.TiledMma
    :param mma_tiler_mnk: The mma tiler shape
    :type mma_tiler_mnk: cute.Tile
    :param sf_vec_size: The scale factor vector size
    :type sf_vec_size: int
    :param smem_layout: The smem layout of SFA per stage
    :type smem_layout: cute.Layout
    :param loc: Source location for the DSL op
    :param ip: Insertion point for the DSL op

    :return: TMEM layout for SFA
    :rtype: cute.Layout
    """
    atom_thr_size = cute.size(tiled_mma.thr_id.shape)
    cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size

    # The dialect helper produces a layout type; _cute_ir.static turns it into
    # the value that is returned. Only this last call receives loc/ip.
    sfa_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfa(
        smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size
    )
    return _cute_ir.static(sfa_layout_ty, loc=loc, ip=ip)


@dsl_user_op
def make_tmem_layout_sfb(
    tiled_mma: cute.TiledMma,
    mma_tiler_mnk: cute.Tile,
    sf_vec_size: int,
    smem_layout: cute.Layout,
    *,
    loc=None,
    ip=None,
) -> cute.Layout:
    """Make tmem layout for SFB based on:
    1. SFB smem layout per stage
    2. Cta tile shape m
    3. tiled MMA atom thr size
    4. Scale factor vector size

    :param tiled_mma: The tiled MMA
    :type tiled_mma: cute.TiledMma
    :param mma_tiler_mnk: The mma tiler shape
    :type mma_tiler_mnk: cute.Tile
    :param sf_vec_size: The scale factor vector size
    :type sf_vec_size: int
    :param smem_layout: The smem layout of SFB per stage
    :type smem_layout: cute.Layout
    :param loc: Source location for the DSL op
    :param ip: Insertion point for the DSL op

    :return: TMEM layout for SFB
    :rtype: cute.Layout
    """
    atom_thr_size = cute.size(tiled_mma.thr_id.shape)
    # NOTE(review): the SFB layout is derived from the M mode of the tiler
    # (mma_tiler_mnk[0]), exactly like the SFA variant -- confirm that the
    # dialect helper really expects the M extent here and not N.
    cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size

    # The dialect helper produces a layout type; _cute_ir.static turns it into
    # the value that is returned. Only this last call receives loc/ip.
    sfb_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfb(
        smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size
    )
    return _cute_ir.static(sfb_layout_ty, loc=loc, ip=ip)


# diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/distributed_helpers.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/distributed_helpers.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..5853c56c84f6fc02e911537147fa03b6b4566117
# --- /dev/null
# +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/distributed_helpers.py
# @@ -0,0 +1,179 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

from functools import partial
from typing import Tuple

import cutlass.cute as cute
from cutlass.cutlass_dsl import T, dsl_user_op, while_generate

from cutlass._mlir import ir
from cutlass._mlir.dialects import arith, llvm, nvvm, scf
from cutlass._mlir.dialects.nvvm import (
    MemOrderKind,
    MemScopeKind,
    AtomicOpKind,
)
from cutlass.cute.typing import Pointer, Int32, Boolean


@dsl_user_op
def atomicAdd(dst_ptr: Pointer, val: Int32, loc=None, ip=None) -> Int32:
    """
    Emit a 32-bit atomic ADD of ``val`` on ``dst_ptr``.

    The op is issued with relaxed memory order at system scope.

    :param dst_ptr: Pointer to the i32 location to update
    :type dst_ptr: Pointer
    :param val: The addend
    :type val: Int32
    :param loc: Source location for the emitted op
    :param ip: Insertion point for the emitted op
    :return: The result of the ``nvvm.atomicrmw`` op (presumably the value
        stored before the add -- confirm against the NVVM dialect)
    :rtype: Int32
    """
    return nvvm.atomicrmw(
        T.i32(),
        AtomicOpKind.ADD,
        dst_ptr.llvm_ptr,
        val.ir_value(loc=loc, ip=ip),
        mem_order=MemOrderKind.RELAXED,
        syncscope=MemScopeKind.SYS,
        loc=loc,
        ip=ip,
    )


@cute.jit
def ld_bypass(input_tensor: cute.Tensor):
    """
    Load ``input_tensor`` into a fragment using a volatile, system-scope copy.

    :param input_tensor: The tensor to read
    :type input_tensor: cute.Tensor
    :return: The values loaded from the fragment (``fragment.load()``)
    """
    # Destination fragment with the same layout and element type as the input.
    fragment = cute.make_fragment(input_tensor.layout, input_tensor.element_type)
    # Universal copy atom configured with VOLATILE order and SYS scope.
    copy_atom_load = cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        input_tensor.element_type,
        memory_order=cute.nvgpu.common.MemoryOrder.VOLATILE,
        memory_scope=cute.nvgpu.common.MemoryScope.SYS,
    )
    cute.copy(copy_atom_load, input_tensor, fragment)
    vals = fragment.load()
    return vals

@cute.jit
def spin_lock_wait(lock_ptr: Pointer, expect_count: Int32, mem_order : str = "relaxed", mem_scope : str = "gpu", loc=None, ip=None) -> None:
    """
    wait on a spin lock until the expected count is reached.

    Repeatedly issues an atomic CAS on ``lock_ptr`` until the value it returns
    equals ``expect_count``.

    :param lock_ptr: Pointer to the i32 lock word
    :type lock_ptr: Pointer
    :param expect_count: The value to wait for
    :type expect_count: Int32
    :param mem_order: ``"acquire"`` selects acquire ordering; any other value
        selects relaxed
    :type mem_order: str
    :param mem_scope: ``"gpu"`` selects GPU scope; any other value selects
        system scope
    :type mem_scope: str
    :param loc: Source location, used only for the two operand conversions
    :param ip: Insertion point, used only for the two operand conversions
    """
    res = 0
    while res != expect_count:
        # NOTE(review): the CAS is issued with operands 0 and expect_count;
        # which one is the compare value and which the swap value depends on
        # nvvm.atomicrmw's operand order -- confirm (it looks like a matched
        # lock word is meant to be reset to 0).
        # NOTE(review): loc/ip are not passed to nvvm.atomicrmw itself, unlike
        # in atomicAdd -- confirm whether that is intended.
        res = nvvm.atomicrmw(
            T.i32(),
            AtomicOpKind.CAS,
            lock_ptr.llvm_ptr,
            Int32(0).ir_value(loc=loc, ip=ip),
            b=Int32(expect_count).ir_value(loc=loc, ip=ip),
            mem_order=MemOrderKind.ACQUIRE if mem_order == "acquire" else MemOrderKind.RELAXED,
            syncscope=MemScopeKind.GPU if mem_scope == "gpu" else MemScopeKind.SYS
        )


@dsl_user_op
def multimem_red_add_sys_release(mc_ptr: Pointer, loc=None, ip=None) -> None:
    """
    add 1 to the multimem address

    Emits ``multimem.red.release.sys.global.add.u32 [mc_ptr], 1`` as inline
    PTX, i.e. the release / system-scope variant.

    :param mc_ptr: The multimem address to update
    :type mc_ptr: Pointer
    :param loc: Source location for the inline asm op
    :param ip: Insertion point for the inline asm op
    """
    llvm.inline_asm(
        None,
        [mc_ptr.toint().ir_value()],
        "multimem.red.release.sys.global.add.u32 [$0], 1;",
        "l",
        has_side_effects=True,
        asm_dialect=0,
        loc=loc,
        ip=ip,
    )

@dsl_user_op
def multimem_red_add_gpu_relaxed(mc_ptr: Pointer, loc=None, ip=None) -> None:
    """
    add 1 to the multimem address

    Emits ``multimem.red.relaxed.gpu.global.add.u32 [mc_ptr], 1`` as inline
    PTX, i.e. the relaxed / GPU-scope variant.

    :param mc_ptr: The multimem address to update
    :type mc_ptr: Pointer
    :param loc: Source location for the inline asm op
    :param ip: Insertion point for the inline asm op
    """
    llvm.inline_asm(
        None,
        [mc_ptr.toint().ir_value()],
        "multimem.red.relaxed.gpu.global.add.u32 [$0], 1;",
        "l",
        has_side_effects=True,
        asm_dialect=0,
        loc=loc,
        ip=ip,
    )


def spin_lock_multimem_arrive(lock_ptr: Pointer, loc=None, ip=None) -> None:
    """
    arrive a spin lock when the lock_ptr is a multimem address.

    Thin wrapper over :func:`multimem_red_add_gpu_relaxed`.

    :param lock_ptr: The multimem address of the lock word
    :type lock_ptr: Pointer
    :param loc: Source location forwarded to the reduction
    :param ip: Insertion point forwarded to the reduction
    """
    multimem_red_add_gpu_relaxed(lock_ptr, loc=loc, ip=ip)


def sm_wise_inter_gpu_multimem_barrier(barrier : Pointer, barrier_mc : Pointer, num_ranks, loc=None, ip=None) -> None :
    """
    barrier for inter-gpu sm-wise

    Each block adds 1 to its own slot through ``barrier_mc`` and then spins
    on the same slot of ``barrier`` until it reads ``num_ranks``.

    :param barrier: Base pointer of the barrier slots that are waited on
    :type barrier: Pointer
    :param barrier_mc: Base pointer of the barrier slots used for the
        multimem reduction
    :type barrier_mc: Pointer
    :param num_ranks: The count each slot must reach
    :param loc: Source location forwarded to the helpers
    :param ip: Insertion point forwarded to the helpers
    """
    bidx, bidy, bidz = cute.arch.block_idx()
    # NOTE: despite the "bdim" names these are grid dimensions.
    bdimx, bdimy, _ = cute.arch.grid_dim()
    # Linear block id within the grid; selects this block's barrier slot.
    pid = bidx + bidy * bdimx + bidz * bdimx * bdimy
    multimem_red_add_sys_release(barrier_mc + pid, loc=loc, ip=ip)
    cute.arch.fence_proxy(cute.arch.ProxyKind.alias)
    spin_lock_wait(barrier + pid, num_ranks, mem_order="acquire", mem_scope="sys", loc=loc, ip=ip)


@dsl_user_op
def multimem_ld_reduce_base(
    mc_ptr: Pointer,
    *,
    ptx_string: str = "",
    loc=None,
    ip=None,
) -> Tuple[Int32, Int32, Int32, Int32]:
    """
    Emit the inline PTX given in ``ptx_string`` and return its four outputs.

    The asm has four 32-bit register outputs (``$0``-``$3``) and one 64-bit
    address input (``$4``). The ``multimem_ld_reduce_*`` partials below bind
    ``ptx_string`` to concrete ``multimem.ld_reduce`` instructions.

    :param mc_ptr: The multimem address to read
    :type mc_ptr: Pointer
    :param ptx_string: The PTX instruction text to emit; the default is empty
    :type ptx_string: str
    :param loc: Source location for the emitted ops
    :param ip: Insertion point for the emitted ops
    :return: The four i32 values extracted from the asm result struct
    :rtype: Tuple[Int32, Int32, Int32, Int32]
    """
    # The element type actually reduced is decided by ptx_string (the original
    # comment here said "8xf16", which only holds for one of the partials).
    mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value()
    return_struct = llvm.inline_asm(
        ir.Type.parse("!llvm.struct<(i32,i32,i32,i32)>"),
        [mc_ptr_int],
        ptx_string,
        "=r,=r,=r,=r,l",
        has_side_effects=True,
        asm_dialect=0,
        loc=loc,
        ip=ip,
    )
    # NOTE(review): loc/ip are not forwarded to llvm.extractvalue -- confirm.
    return_regs = [llvm.extractvalue(T.i32(), return_struct, [i]) for i in range(4)]
    return return_regs[0], return_regs[1], return_regs[2], return_regs[3]


# Concrete ld_reduce variants; each returns 4 x 32 bits, and the name states how
# many elements of the given type those 128 bits hold.
multimem_ld_reduce_8xf16 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.f16x2 {$0, $1, $2, $3}, [$4];")
multimem_ld_reduce_4xf32 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.v4.f32 {$0, $1, $2, $3}, [$4];")
multimem_ld_reduce_8xbf16 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f32.v4.bf16x2 {$0, $1, $2, $3}, [$4];")
multimem_ld_reduce_16xe4m3 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e4m3x4 {$0, $1, $2, $3}, [$4];")
multimem_ld_reduce_16xe5m2 = partial(multimem_ld_reduce_base, ptx_string="multimem.ld_reduce.sys.relaxed.global.add.acc::f16.v4.e5m2x4 {$0, $1, $2, $3}, [$4];")


@dsl_user_op
def multimem_st_4xb32(
    mc_ptr: Pointer,
    x: Int32,
    y: Int32,
    z: Int32,
    w: Int32,
    *,
    loc=None,
    ip=None,
) -> None:
    """
    Store four 32-bit values to a multimem address with one inline PTX op.

    Emits ``multimem.st.sys.relaxed.global.v4.f32 [mc_ptr], {x, y, z, w}``.

    :param mc_ptr: The multimem destination address
    :type mc_ptr: Pointer
    :param x: First 32-bit word
    :type x: Int32
    :param y: Second 32-bit word
    :type y: Int32
    :param z: Third 32-bit word
    :type z: Int32
    :param w: Fourth 32-bit word
    :type w: Int32
    :param loc: Source location for the emitted op
    :param ip: Insertion point for the emitted op
    """
    # st 4x32 bits of data
    mc_ptr_int = mc_ptr.toint(loc=loc, ip=ip).ir_value()
    # NOTE(review): the asm declares an i32 result with an "=r" output ($0)
    # that the instruction text never writes, which is why the real operands
    # start at $1 -- confirm this dummy output is deliberate.
    # NOTE(review): x, y, z, w are passed without .ir_value(), unlike the
    # operands elsewhere in this file -- confirm the implicit conversion.
    llvm.inline_asm(
        T.i32(),
        [mc_ptr_int, x, y, z, w],
        "multimem.st.sys.relaxed.global.v4.f32 [$1], {$2, $3, $4, $5};",
        "=r,l,r,r,r,r",
        has_side_effects=True,
        asm_dialect=0,
        loc=loc,
        ip=ip,
    )


# diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/grouped_gemm_tile_scheduler_helper.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/grouped_gemm_tile_scheduler_helper.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..a51bae62963bd482fd590f824a4bc1c8564ece0e
# --- /dev/null
# +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/grouped_gemm_tile_scheduler_helper.py
# @@ -0,0 +1,466 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

from typing import List, Tuple

import cutlass.cute as cute
from cutlass.cutlass_dsl import Int32, extract_mlir_values, new_from_mlir_values
from cutlass._mlir import ir

from cutlass.utils.static_persistent_tile_scheduler import PersistentTileSchedulerParams


class GroupSearchResult:
    """
    The result of the group search for grouped gemm.

    :param group_idx: The result group index
    :type group_idx: Int32
    :param cta_tile_idx_m: CTA tile index along M dimension after rasterization
    :type cta_tile_idx_m: Int32
    :param cta_tile_idx_n: CTA tile index along N dimension after rasterization
    :type cta_tile_idx_n: Int32
    :param problem_shape_m: The M dimension of the gemm problem
    :type problem_shape_m: Int32
    :param problem_shape_n: The N dimension of the gemm problem
    :type problem_shape_n: Int32
    :param problem_shape_k: The K dimension of the gemm problem
    :type problem_shape_k: Int32
    :param cta_tile_count_k: Number of tiles along K dimension
    :type cta_tile_count_k: Int32
    """

    def __init__(
        self,
        group_idx: Int32,
        cta_tile_idx_m: Int32,
        cta_tile_idx_n: Int32,
        problem_shape_m: Int32,
        problem_shape_n: Int32,
        problem_shape_k: Int32,
        cta_tile_count_k: Int32,
    ) -> None:
        """Store the seven result fields unchanged."""
        self.group_idx = group_idx
        self.cta_tile_idx_m = cta_tile_idx_m
        self.cta_tile_idx_n = cta_tile_idx_n
        self.problem_shape_m = problem_shape_m
        self.problem_shape_n = problem_shape_n
        self.problem_shape_k = problem_shape_k
        self.cta_tile_count_k = cta_tile_count_k

    def __extract_mlir_values__(self) -> List[ir.Value]:
        """Flatten all seven fields, in constructor order, into one list of MLIR values."""
        values = extract_mlir_values(self.group_idx)
        values.extend(extract_mlir_values(self.cta_tile_idx_m))
        values.extend(extract_mlir_values(self.cta_tile_idx_n))
        values.extend(extract_mlir_values(self.problem_shape_m))
        values.extend(extract_mlir_values(self.problem_shape_n))
        values.extend(extract_mlir_values(self.problem_shape_k))
        values.extend(extract_mlir_values(self.cta_tile_count_k))
        return values

    def __new_from_mlir_values__(self, values: List[ir.Value]) -> "GroupSearchResult":
        """Rebuild a result from exactly seven values, in constructor order."""
        # NOTE(review): the raw values are passed straight to the constructor,
        # whereas GroupedGemmGroupSearchState wraps each one with
        # new_from_mlir_values -- confirm the fields are meant to hold raw
        # ir.Value objects after reconstruction.
        # NOTE(review): the length check is an assert and disappears under -O.
        assert len(values) == 7
        return GroupSearchResult(*tuple(values))


class GroupedGemmGroupSearchState:
    """
    The state of group index search for grouped gemm.

    The state will be initialized once and updated in every round of group index search.

    :param start_group_idx: The group idx to start the search with
    :type start_group_idx: Int32
    :param tile_count_prev_group: Number of tiles before the matched group
    :type tile_count_prev_group: Int32
    :param tile_count_searched: Number of tiles we have searched. When the matched group is found,
        it records the number of tiles including the matched group
    :type tile_count_searched: Int32
    """

    def __init__(
        self,
        start_group_idx: Int32,
        tile_count_prev_group: Int32,
        tile_count_searched: Int32,
    ) -> None:
        """Store the three state fields unchanged."""
        self.start_group_idx = start_group_idx
        self.tile_count_prev_group = tile_count_prev_group
        self.tile_count_searched = tile_count_searched

    def __extract_mlir_values__(self) -> List[ir.Value]:
        """Flatten the three fields, in constructor order, into one list of MLIR values."""
        values = extract_mlir_values(self.start_group_idx)
        values.extend(extract_mlir_values(self.tile_count_prev_group))
        values.extend(extract_mlir_values(self.tile_count_searched))
        return values

    def __new_from_mlir_values__(
        self, values: List[ir.Value]
    ) -> "GroupedGemmGroupSearchState":
        """
        Rebuild a state from the first three entries of ``values``.

        Each field is rebuilt from exactly one value, using the current field
        as the template; entries after index 2 are ignored.
        """
        start_group_idx = new_from_mlir_values(self.start_group_idx, [values[0]])
        tile_count_prev_group = new_from_mlir_values(
            self.tile_count_prev_group, [values[1]]
        )
        tile_count_searched = new_from_mlir_values(
            self.tile_count_searched, [values[2]]
        )
        return GroupedGemmGroupSearchState(
            start_group_idx, tile_count_prev_group, tile_count_searched
        )


def create_initial_search_state() -> GroupedGemmGroupSearchState:
    """
    Create an initial search state for grouped gemm.

    All three fields start at ``Int32(0)``.

    :return: A new search state with initial values
    :rtype: GroupedGemmGroupSearchState
    """
    return GroupedGemmGroupSearchState(
        start_group_idx=Int32(0),
        tile_count_prev_group=Int32(0),
        tile_count_searched=Int32(0),
    )


class GroupedGemmTileSchedulerHelper:
    """
    A helper to translate the raw block index (x, y, z) from tile scheduler to real CTA tile index for grouped gemm.

    :param group_count: Number of groups in current grouped gemm problem
    :type group_count: int
    :param tile_sched_params: Parameter used to create the tile scheduler this helper works with
    :type tile_sched_params: PersistentTileSchedulerParams
    :param cluster_tile_shape_mnk: The shape of cluster tile as (m, n, k)
    :type cluster_tile_shape_mnk: tuple[int, int, int]
    :param search_state: The initial search state
    :type search_state: GroupedGemmGroupSearchState
    """

    def __init__(
        self,
        group_count: int,
        tile_sched_params: PersistentTileSchedulerParams,
        cluster_tile_shape_mnk: tuple[int, int, int],
        search_state: GroupedGemmGroupSearchState,
    ) -> None:
        """Store the arguments and read the current lane index."""
        self.tile_sched_params = tile_sched_params
        self.group_count = group_count
        # Read once at construction and reused by the warp-wide helpers below.
        self.lane_idx = cute.arch.lane_idx()
        self.cluster_tile_shape_mnk = cluster_tile_shape_mnk
        self.search_state = search_state

    def __extract_mlir_values__(self) -> List[ir.Value]:
        """
        Flatten the scheduler params followed by the search state.

        ``group_count``, ``cluster_tile_shape_mnk`` and ``lane_idx`` are not
        part of the extracted values.
        """
        values = extract_mlir_values(self.tile_sched_params)
        values.extend(extract_mlir_values(self.search_state))
        return values

    def __new_from_mlir_values__(
        self, values: List[ir.Value]
    ) -> "GroupedGemmTileSchedulerHelper":
        """
        Rebuild a helper from values produced by ``__extract_mlir_values__``.

        ``group_count`` and ``cluster_tile_shape_mnk`` are carried over from
        ``self``; ``lane_idx`` is read again by the constructor.
        """
        tile_sched_params = new_from_mlir_values(self.tile_sched_params, values)
        # NOTE(review): values[1:] assumes tile_sched_params contributed
        # exactly one value in __extract_mlir_values__; if it contributes more,
        # the search state is rebuilt from the wrong entries -- verify against
        # PersistentTileSchedulerParams.
        search_state = new_from_mlir_values(self.search_state, values[1:])
        return GroupedGemmTileSchedulerHelper(
            self.group_count,
            tile_sched_params,
            self.cluster_tile_shape_mnk,
            search_state,
        )

    def delinearize_z(
        self,
        cta_tile_coord: tuple,
        problem_shape_mnkl: cute.Tensor,
    ) -> GroupSearchResult:
        """
        Delinearize the linear z index and return GroupSearchResult.

        This function should be used by warps that need to know the CTA tile index on M and N dimensions.
        It updates ``self.search_state`` as a side effect of the group search.

        :param cta_tile_coord: The raw CTA coordinate from tile scheduler
        :type cta_tile_coord: tuple of Int32
        :param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for each group
        :type problem_shape_mnkl: cute.Tensor
        :return: The search result containing group index and tile coordinates
        :rtype: GroupSearchResult
        """
        # delinear the z coord
        linear_idx = cta_tile_coord[2]
        # The search resumes from the previously matched group, so the tile
        # count passed in is the count *before* that group.
        group_idx, problem_mnkl = self._group_search_and_load_problem_shape(
            linear_idx,
            problem_shape_mnkl,
            self.search_state.start_group_idx,
            self.search_state.tile_count_prev_group,
        )
        # linear index local to current group
        # (self.search_state was replaced by the call above)
        cluster_tile_idx_in_current_group = (
            linear_idx - self.search_state.tile_count_prev_group
        )
        cluster_count_m, cluster_count_n, cluster_count_k = cute.ceil_div(
            (problem_mnkl[0], problem_mnkl[1], problem_mnkl[2]),
            (
                self.cluster_tile_shape_mnk[0],
                self.cluster_tile_shape_mnk[1],
                self.cluster_tile_shape_mnk[2],
            ),
        )
        # decompose to get indices on M and N
        cta_tile_idx_m, cta_tile_idx_n = self._compute_cta_tile_coord(
            cluster_tile_idx_in_current_group,
            cta_tile_coord,
            cluster_count_m,
            cluster_count_n,
        )
        return GroupSearchResult(
            group_idx,
            cta_tile_idx_m,
            cta_tile_idx_n,
            problem_mnkl[0],
            problem_mnkl[1],
            problem_mnkl[2],
            cluster_count_k,
        )

    def search_cluster_tile_count_k(
        self,
        cta_tile_coord: tuple,
        problem_shape_mnkl: cute.Tensor,
    ) -> Tuple[Int32, Int32]:
        """
        Search the matched group for given linear index and compute the number of tiles along K dimension for the matched group.

        This function should be used by warps that are only interested in the number of tiles along K dimension.
        It updates ``self.search_state`` as a side effect of the group search.

        :param cta_tile_coord: The raw CTA coordinate from tile scheduler
        :type cta_tile_coord: tuple of Int32
        :param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for all groups
        :type problem_shape_mnkl: cute.Tensor
        :return: A tuple containing cluster count along K dimension and the group index
        :rtype: Tuple[Int32, Int32]
        """
        group_idx, problem_mnk = self._group_search_and_load_problem_shape(
            cta_tile_coord[2],
            problem_shape_mnkl,
            self.search_state.start_group_idx,
            self.search_state.tile_count_prev_group,
        )
        # Ceiling division of K by the cluster tile's K extent.
        cluster_count_k = (
            problem_mnk[2] + self.cluster_tile_shape_mnk[2] - 1
        ) // self.cluster_tile_shape_mnk[2]
        return cluster_count_k, group_idx

    @cute.jit
    def _prefix_sum(self, value_per_thread: Int32) -> Int32:
        """
        Perform prefix sum within a full warp.

        The shuffle distance doubles each round (1, 2, 4, ...) while it is
        below ``cute.arch.WARP_SIZE``; a lane only accumulates the shuffled
        value when its lane index is at least the current distance.

        :param value_per_thread: The value for this thread to contribute to the prefix sum
        :type value_per_thread: Int32
        :return: The prefix sum result for this thread
        :rtype: Int32
        """
        clamp_value = 0
        idx = 1
        sum_per_thread = value_per_thread
        while idx < cute.arch.WARP_SIZE:
            value = cute.arch.shuffle_sync_up(
                sum_per_thread, idx, mask_and_clamp=clamp_value
            )
            # Lanes below the shuffle distance have no source lane to add.
            if self.lane_idx >= idx:
                sum_per_thread += value
            idx = idx << 1
        return sum_per_thread

    def _get_problem_for_group(
        self, problem_shape_mnkl: cute.Tensor, group_idx: Int32
    ) -> cute.Tensor:
        """
        Load gemm problem (m,n,k,l) for the specified group from global memory to register.

        :param problem_shape_mnkl: Tensor in global memory with layout (group_count, 4):(4, 1)
        :type problem_shape_mnkl: cute.Tensor
        :param group_idx: The index of the group to load
        :type group_idx: Int32
        :return: The problem shape tensor for the specified group
        :rtype: cute.Tensor
        """
        # 4-element fragment with the same element type as the source tensor.
        cur_problem_mnkl = cute.make_fragment(
            cute.make_layout(4), problem_shape_mnkl.element_type
        )
        # Copy the whole row selected by group_idx.
        cute.autovec_copy(problem_shape_mnkl[(group_idx, None)], cur_problem_mnkl)
        return cur_problem_mnkl

    def _get_cluster_tile_count_mn(self, problem_shape: cute.Tensor) -> Int32:
        """
        Compute total cluster count.

        The count is ``ceil(m / tile_m) * ceil(n / tile_n)`` using the M and N
        extents of ``self.cluster_tile_shape_mnk``.

        :param problem_shape: Tensor containing problem shape (m, n, k, l)
        :type problem_shape: cute.Tensor
        :return: The total cluster tile count for M and N dimensions
        :rtype: Int32
        """
        cur_ntile_m = (
            problem_shape[0] + self.cluster_tile_shape_mnk[0] - 1
        ) // self.cluster_tile_shape_mnk[0]
        cur_ntile_n = (
            problem_shape[1] + self.cluster_tile_shape_mnk[1] - 1
        ) // self.cluster_tile_shape_mnk[1]
        cur_ntile_mn = cur_ntile_m * cur_ntile_n
        return cur_ntile_mn

    def _compute_cta_tile_coord(
        self,
        cluster_tile_idx: Int32,
        cta_tile_coord_in_cluster: tuple,
        cluster_tile_count_m: Int32,
        cluster_tile_count_n: Int32,
    ) -> tuple:
        """
        Compute CTA tile indices along M and N dimensions based on the linear index within a group.

        It uses the AlongM mode to decompose the linear index onto M and N dimensions.

        :param cluster_tile_idx: The linear index within a group
        :type cluster_tile_idx: Int32
        :param cta_tile_coord_in_cluster: CTA indices along M and N dimensions within a cluster
        :type cta_tile_coord_in_cluster: tuple of Int32
        :param cluster_tile_count_m: The number of clusters along M dimension of the matched group
        :type cluster_tile_count_m: Int32
        :param cluster_tile_count_n: The number of clusters along N dimension of the matched group
        :type cluster_tile_count_n: Int32
        :return: A tuple containing CTA tile indices along M and N dimensions
        :rtype: tuple of (Int32, Int32)
        """
        # Layout over the (M, N) cluster grid, built without explicit strides.
        cluster_layout_mn = cute.make_layout(
            (cluster_tile_count_m, cluster_tile_count_n)
        )
        (mi, ni) = cluster_layout_mn.get_hier_coord(cluster_tile_idx)
        # Scale the cluster coordinate by the cluster shape and add the CTA's
        # coordinate inside the cluster.
        cta_tile_idx_m = (
            mi * self.tile_sched_params.cluster_shape_mn[0]
            + cta_tile_coord_in_cluster[0]
        )
        cta_tile_idx_n = (
            ni * self.tile_sched_params.cluster_shape_mn[1]
            + cta_tile_coord_in_cluster[1]
        )
        return (cta_tile_idx_m, cta_tile_idx_n)

    @cute.jit
    def _group_search(
        self,
        linear_idx: Int32,
        problem_shape_mnkl: cute.Tensor,
        init_group_idx: Int32,
        init_tile_count_searched: Int32,
    ) -> GroupedGemmGroupSearchState:
        """
        Search which group the linear index belongs to.

        Each loop iteration examines a window of ``cute.arch.WARP_SIZE``
        consecutive groups, one group per lane, and advances by the number of
        groups that end at or before ``linear_idx``.

        :param linear_idx: The linear index to be decomposed
        :type linear_idx: Int32
        :param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for all groups
        :type problem_shape_mnkl: cute.Tensor
        :param init_group_idx: The group idx to start the search with
        :type init_group_idx: Int32
        :param init_tile_count_searched: The number of tiles we have searched
        :type init_tile_count_searched: Int32
        :return: The updated search state
        :rtype: GroupedGemmGroupSearchState
        """
        c_0 = Int32(0).ir_value()
        last_lane_idx = cute.arch.WARP_SIZE - 1

        tile_count_searched = init_tile_count_searched
        start_group_idx = init_group_idx
        # The loop is skipped entirely when linear_idx is below the initial
        # tile count; the incoming state is then returned as is.
        not_found = linear_idx >= tile_count_searched
        tile_count_prev_group = self.search_state.tile_count_prev_group
        while not_found:
            # get group to search for current lane
            cur_group_idx = start_group_idx + self.lane_idx
            # check if the group to be checked is out of range
            inside_group_bound = cur_group_idx < self.group_count
            # Lanes past the last group contribute zero tiles.
            cur_ntile_mn = c_0
            if inside_group_bound:
                # get problem size of current group
                cur_problem_mnkl = self._get_problem_for_group(
                    problem_shape_mnkl, cur_group_idx
                )
                cur_ntile_mn = self._get_cluster_tile_count_mn(cur_problem_mnkl)
            # compute tile count from beginning to current group(included)
            total_cluster_tile_count_ps_per_thread = self._prefix_sum(cur_ntile_mn)
            cluster_tile_count_end_per_thread = (
                total_cluster_tile_count_ps_per_thread + tile_count_searched
            )

            # True on lanes whose group ends at or before linear_idx.
            group_not_in_window = linear_idx >= cluster_tile_count_end_per_thread
            # Presumably the number of lanes voting true, i.e. the offset of
            # the matching group inside this window.
            hitted_group_idx_in_search_window = cute.arch.popc(
                cute.arch.vote_ballot_sync(group_not_in_window)
            )
            # A count of WARP_SIZE means no group in this window matched.
            not_found = hitted_group_idx_in_search_window == cute.arch.WARP_SIZE
            start_group_idx = hitted_group_idx_in_search_window + start_group_idx
            hit_the_1st_problem_in_search_window = (
                hitted_group_idx_in_search_window == c_0
            )
            # Default: the match is the window's first group, so the tiles
            # before it are exactly those searched before this window.
            tile_count_prev_group = tile_count_searched
            if hit_the_1st_problem_in_search_window == False:
                # Otherwise take the running end count of the lane just before
                # the matching one.
                tile_count_prev_group = cute.arch.shuffle_sync(
                    cluster_tile_count_end_per_thread,
                    hitted_group_idx_in_search_window - 1,
                )

            # If no matched group, then get new_cluster_tile_count_end from last lane
            # Otherwise, get new_cluster_tile_count_end from the hitted group
            lane_idx_for_cluster_tile_count_end = hitted_group_idx_in_search_window
            if not_found:
                lane_idx_for_cluster_tile_count_end = last_lane_idx
            tile_count_searched = cute.arch.shuffle_sync(
                cluster_tile_count_end_per_thread,
                lane_idx_for_cluster_tile_count_end,
            )

        return GroupedGemmGroupSearchState(
            start_group_idx,
            tile_count_prev_group,
            tile_count_searched,
        )

    def _group_search_and_load_problem_shape(
        self,
        linear_idx: Int32,
        problem_shape_mnkl: cute.Tensor,
        start_group_idx: Int32,
        tile_count_searched: Int32,
    ) -> Tuple[Int32, cute.Tensor]:
        """
        Perform group search and load problem shape for the matched group.

        Replaces ``self.search_state`` with the state returned by the search.

        :param linear_idx: The linear index to be decomposed
        :type linear_idx: Int32
        :param problem_shape_mnkl: Tensor containing gemm problem size (M, N, K, L) for all groups
        :type problem_shape_mnkl: cute.Tensor
        :param start_group_idx: The group idx to start the search with
        :type start_group_idx: Int32
        :param tile_count_searched: The number of tiles we have searched
        :type tile_count_searched: Int32
        :return: A tuple containing the final group index and the problem shape tensor
        :rtype: Tuple[Int32, cute.Tensor]
        """
        self.search_state = self._group_search(
            linear_idx,
            problem_shape_mnkl,
            start_group_idx,
            tile_count_searched,
        )
        # get final group search state
        final_group_idx = self.search_state.start_group_idx
        # let's revisit if it's better to broadcast problem_shape_mnk in group_search
        problem_mnkl = self._get_problem_for_group(problem_shape_mnkl, final_group_idx)
        return final_group_idx, problem_mnkl


# diff --git
a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/hardware_info.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/hardware_info.py new file mode 100644 index 0000000000000000000000000000000000000000..e86fcbefc86fbc7da333735fa2cebbd3af47f39e --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/hardware_info.py @@ -0,0 +1,174 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from cuda.bindings import driver, nvrtc + +import cutlass.cute as cute + +""" +This class is used to get the hardware info of given GPU device. +It provides methods to get the max active clusters for given cluster size. + +Prerequisite: +- CUDA driver is initialized via `driver.cuInit` or other CUDA APIs. +- CUDA context is created via `driver.cuCtxCreate` or other CUDA APIs. + +""" + + +class HardwareInfo: + """ + device_id: CUDA device ID to get the hardware info. 
+ """ + + def __init__(self, device_id: int = 0): + count = self._checkCudaErrors(driver.cuDeviceGetCount()) + if device_id >= count: + raise ValueError( + f"Device ID {device_id} is out of range for device count {count}" + ) + self.device_id = device_id + self.device = self._checkCudaErrors(driver.cuDeviceGet(device_id)) + self.context = self._checkCudaErrors(driver.cuCtxGetCurrent()) + self.driver_version = self._checkCudaErrors(driver.cuDriverGetVersion()) + + # Getting the max active clusters for a given cluster size + def get_max_active_clusters(self, cluster_size: int) -> int: + self._get_device_function() + if self._cuda_driver_version_lt(11, 8): + raise RuntimeError( + "CUDA Driver version < 11.8, cannot get _max_active_clusters" + ) + if cluster_size <= 0 or cluster_size > 32: + raise ValueError( + f"Cluster size must be between 1 and 32, {cluster_size} is not supported" + ) + + max_shared_memory_per_block = self._checkCudaErrors( + driver.cuDeviceGetAttribute( + driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + self.device, + ) + ) + self._checkCudaErrors( + driver.cuFuncSetAttribute( + self.kernel, + driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + max_shared_memory_per_block, + ) + ) + max_dynamic_shared_memory = self._checkCudaErrors( + driver.cuOccupancyAvailableDynamicSMemPerBlock( + self.kernel, 1, 1 # numBlocks # blockSize + ) + ) + max_active_blocks = self._checkCudaErrors( + driver.cuOccupancyMaxActiveBlocksPerMultiprocessor( + self.kernel, 1, max_dynamic_shared_memory # blockSize, + ) + ) + # allow non-portable cluster size to support detection of non-portable cluster size + self._checkCudaErrors( + driver.cuFuncSetAttribute( + self.kernel, + driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, + 1, + ) + ) + # prepare launch configuration + launch_config = driver.CUlaunchConfig() + launch_config.blockDimX = 128 + launch_config.blockDimY = 1 + 
launch_config.blockDimZ = 1 + launch_config.sharedMemBytes = max_dynamic_shared_memory + launch_config.numAttrs = 1 + # max possible cluster size is 32 + cluster_dims_attr = driver.CUlaunchAttribute() + cluster_dims_attr.id = ( + driver.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION + ) + value = driver.CUlaunchAttributeValue() + value.clusterDim.x = cluster_size + value.clusterDim.y = 1 + value.clusterDim.z = 1 + cluster_dims_attr.value = value + launch_config.attrs = [cluster_dims_attr] + launch_config.gridDimX = cluster_size + launch_config.gridDimY = max_active_blocks + launch_config.gridDimZ = 1 + + num_clusters = self._checkCudaErrors( + driver.cuOccupancyMaxActiveClusters(self.kernel, launch_config) + ) + return num_clusters + + def get_l2_cache_size_in_bytes(self) -> int: + return self._checkCudaErrors( + driver.cuDeviceGetAttribute( + driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE, + self.device, + ) + ) + + def get_device_multiprocessor_count(self) -> int: + return self._checkCudaErrors( + driver.cuDeviceGetAttribute( + driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, + self.device, + ) + ) + + def _checkCudaErrors(self, result) -> None: + if result[0].value: + raise RuntimeError( + "CUDA error code={}({})".format( + result[0].value, self._cudaGetErrorEnum(result[0]) + ) + ) + # CUDA APIs always return the status as the first element of the result tuple + if len(result) == 1: + return None + elif len(result) == 2: + return result[1] + else: + return result[1:] + + def _cudaGetErrorEnum(self, error) -> str: + if isinstance(error, driver.CUresult): + err, name = driver.cuGetErrorName(error) + return name if err == driver.CUresult.CUDA_SUCCESS else "" + elif isinstance(error, nvrtc.nvrtcResult): + return nvrtc.nvrtcGetErrorString(error)[1] + else: + raise RuntimeError("Unknown error type: {}".format(error)) + + def _cuda_driver_version_ge(self, major: int, minor: int) -> bool: + return self.driver_version >= (major 
* 1000 + 10 * minor) + + def _cuda_driver_version_lt(self, major: int, minor: int) -> bool: + return not self._cuda_driver_version_ge(major, minor) + + @cute.kernel + def _empty_kernel(self): + return + + @cute.jit + def _host_function(self): + self._empty_kernel().launch( + grid=[1, 1, 1], + block=[1, 1, 1], + ) + + # get a empty kernel to compute occupancy + def _get_device_function(self) -> None: + self.compiled_kernel = cute.compile(self._host_function) + self.module = next(iter(self.compiled_kernel.cuda_modules.modules)).cuda_module + self.kernel = next(iter(self.compiled_kernel.cuda_modules.modules)).kernel_ptr diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/hopper_helpers.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/hopper_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..4cd2bae3de66983dc5bf7883305f6a926b3c0d72 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/hopper_helpers.py @@ -0,0 +1,209 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ +from typing import Type, Tuple +from enum import Enum +from typing_extensions import deprecated +import warnings + +from cutlass.utils.layout import LayoutEnum +from cutlass.cutlass_dsl import ( + Float16, + BFloat16, + Float8E5M2, + Float8E4M3FN, + Numeric, + NumericMeta, + dsl_user_op, +) + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu.common import CopyUniversalOp +from cutlass.cute.nvgpu.warp import StMatrix8x8x16bOp +from cutlass.cute.nvgpu.warpgroup import ( + MmaF16BF16Op, + MmaF8Op, + OperandMajorMode, + OperandSource, +) + + +@deprecated("Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead") +class SmemCapacity(Enum): + SM90_SMEM_CAPACITY_BYTES = (228 - 1) * 1024 + + +warnings.warn( + "SMEM_CAPACITY is deprecated: Use get_smem_capacity_in_bytes from cutlass.utils.smem_capacity instead", + DeprecationWarning, + stacklevel=2, +) +# Dictionary to map compute capability to SMEM capacity +SMEM_CAPACITY = { + "sm90": SmemCapacity.SM90_SMEM_CAPACITY_BYTES.value, +} + + +@dsl_user_op +def sm90_get_smem_store_op( + layout_d: LayoutEnum, + elem_ty_d: Type[Numeric], + elem_ty_acc: Type[Numeric], + *, + loc=None, + ip=None, +) -> cute.CopyAtom: + """ + Selects the largest vectorized smem store atom available subject to constraint of gmem layout. + + Parameters: + ----------- + layout_d : LayoutEnum + The layout enum of the output tensor D. + + elem_ty_d : Type[Numeric] + The element type for output tensor D. + + elem_ty_acc : Type[Numeric] + The element type for accumulator. + + Returns: + -------- + Either SmemStoreMatrix or SimtSyncCopy, based on the input parameters. 
+ """ + + def validate_type(ty, ty_name): + if not isinstance(ty, NumericMeta): + raise TypeError(f"{ty_name} must be a Numeric, but got {ty}") + + validate_type(elem_ty_d, "elem_ty_d") + validate_type(elem_ty_acc, "elem_ty_acc") + + is_m_major = layout_d.is_m_major_c() + + if elem_ty_d.width == 16: + return cute.make_copy_atom( + StMatrix8x8x16bOp(is_m_major, 4), elem_ty_d, loc=loc, ip=ip + ) + else: + return cute.make_copy_atom(CopyUniversalOp(), elem_ty_d, loc=loc, ip=ip) + + +def make_trivial_tiled_mma( + a_dtype: Type[Numeric], + b_dtype: Type[Numeric], + a_leading_mode: OperandMajorMode, + b_leading_mode: OperandMajorMode, + acc_dtype: Type[Numeric], + atom_layout_mnk: Tuple[int, int, int], + tiler_mn: Tuple[int, int], + a_source: OperandSource = OperandSource.SMEM, + *, + loc=None, + ip=None, +) -> cute.TiledMma: + """Make a tiled MMA atom with given data type, leading dimension, cta group and mma tile shape. + By default, the MMA atom is created with SMEM operand source for A. + + :param a_dtype: Data type of operand A. + :type a_dtype: type[Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[Numeric] + :param a_leading_mode: Leading dimension of operand A (1 for K, 0 for M/N). + :type a_leading_mode: warpgroup.OperandMajorMode + :param b_leading_mode: Leading dimension of operand B (1 for K, 0 for M/N). + :type b_leading_mode: warpgroup.OperandMajorMode + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[Numeric] + :param atom_layout_mnk: A integer tuple describing the tiling of Atom across threads. + :type atom_layout_mnk: Tuple[int, int, int] + :param tiler_mn: The shape (M, N) of the cta tiler. + :type tiler_mn: Tuple[int, int] + + :return: A tiled MMA atom. + :rtype: cute.TiledMma + + :raises TypeError: If the data type is not supported. 
+ """ + + if a_dtype in {Float16, BFloat16}: + if cutlass.const_expr(a_dtype != b_dtype): + raise TypeError(f"Type mismatch: {a_dtype} != {b_dtype}") + if cutlass.const_expr(a_dtype.width != b_dtype.width): + raise TypeError(f"Type width mismatch: {a_dtype.width} != {b_dtype.width}") + + mma_op = MmaF16BF16Op( + a_dtype, + acc_dtype, + (*tiler_mn, 16), + a_source, + a_leading_mode, + b_leading_mode, + ) + elif a_dtype in {Float8E4M3FN, Float8E5M2} and b_dtype in { + Float8E4M3FN, + Float8E5M2, + }: + mma_op = MmaF8Op( + a_dtype, + b_dtype, + acc_dtype, + (*tiler_mn, 32), + a_source, + a_leading_mode, + b_leading_mode, + ) + else: + raise TypeError(f"unsupported a_dtype and b_dtype, got {a_dtype} and {b_dtype}") + + return cute.make_tiled_mma(cute.make_mma_atom(mma_op), atom_layout_mnk) + +def get_smem_layout_atom( + layout: LayoutEnum, + element_type: Type[Numeric], + major_mode_size: int, + *, + loc=None, + ip=None, +): + """Select the optimal shared memory layout atom based on parameters. 
+ + :param layout: Layout enum of the tensor + :type layout: LayoutEnum + :param element_type: Data type of the elements + :type element_type: type[cutlass.Numeric] + :param major_mode_size: Size of the major mode dimension + :type major_mode_size: int + + :return: Selected shared memory layout atom kind + :rtype: cute.nvgpu.warpgroup.SmemLayoutAtomKind + """ + assert major_mode_size % 8 == 0 + sw128_num_contiguous_bits = 1024 + sw64_num_contiguous_bits = 512 + sw32_num_contiguous_bits = 256 + major_mode_size_bits = major_mode_size * element_type.width + if layout.sm90_mma_major_mode() == OperandMajorMode.MN: + if major_mode_size_bits % sw128_num_contiguous_bits == 0: + return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW128 + if major_mode_size_bits % sw64_num_contiguous_bits == 0: + return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW64 + if major_mode_size_bits % sw32_num_contiguous_bits == 0: + return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_SW32 + return cute.nvgpu.warpgroup.SmemLayoutAtomKind.MN_INTER + if major_mode_size_bits % sw128_num_contiguous_bits == 0: + return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW128 + if major_mode_size_bits % sw64_num_contiguous_bits == 0: + return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW64 + if major_mode_size_bits % sw32_num_contiguous_bits == 0: + return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_SW32 + return cute.nvgpu.warpgroup.SmemLayoutAtomKind.K_INTER diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/layout.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/layout.py new file mode 100644 index 0000000000000000000000000000000000000000..4560c266cf9930ac024adeaa94859d06ecf3650a --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/layout.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from enum import Enum + +import cutlass.cute as cute +from cutlass.cute.nvgpu import warpgroup +from cutlass.cute.nvgpu import tcgen05 + + +class LayoutEnum(Enum): + ROW_MAJOR = "row_major" + COL_MAJOR = "col_major" + + def mma_major_mode(self): + return ( + tcgen05.OperandMajorMode.K + if self == LayoutEnum.ROW_MAJOR + else tcgen05.OperandMajorMode.MN + ) + + def sm90_mma_major_mode(self): + return ( + warpgroup.OperandMajorMode.K + if self == LayoutEnum.ROW_MAJOR + else warpgroup.OperandMajorMode.MN + ) + + def is_n_major_c(self): + return self == LayoutEnum.ROW_MAJOR + + def is_m_major_c(self): + return self == LayoutEnum.COL_MAJOR + + @staticmethod + def from_tensor(tensor: cute.Tensor) -> "LayoutEnum": + ret = None + if tensor.leading_dim == 1: + ret = LayoutEnum.ROW_MAJOR + elif tensor.leading_dim == 0: + ret = LayoutEnum.COL_MAJOR + else: + raise ValueError(f"Invalid leading dimension: {tensor.leading_dim}") + + return ret + + +__all__ = ["LayoutEnum"] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/smem_allocator.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/smem_allocator.py new file mode 100644 index 0000000000000000000000000000000000000000..2500c06e1808bc06db5decce88e8ebf7837f17d0 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/smem_allocator.py @@ -0,0 +1,184 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import Type, Union, overload + +from cutlass.cutlass_dsl import Int8, Numeric, NumericMeta, CutlassBaseDSL + +import cutlass.cute as cute +from cutlass.cute.arch import get_dyn_smem, get_dyn_smem_size + + +class SmemAllocator: + """A class for managing shared memory allocation on GPU. + + This class manages a chunk of shared memory and provides APIs for sub-allocation + inside the chunk. + + :ivar _base: The current base address of the shared memory as an i8 typed dynamic value. + :type _base: cute.Pointer + :ivar _allocated_bytes: The total number of bytes allocated in shared memory. + :type _allocated_bytes: int + + .. note:: + This class is responsible for managing the allocation of tensors in shared memory. + The base pointer is aligned to 1024 bytes upon initialization. + """ + + def __init__(self): + """Initialize the SmemAllocator instance. + + Creates a dynamic shared memory base pointer of type i8, aligned to 1024 bytes. + """ + self._base = get_dyn_smem(Int8, alignment=1024) + self._allocated_bytes = 0 + CutlassBaseDSL.track_smem_allocator(self, lambda cls: cls._allocated_bytes) + + @overload + def allocate(self, size_or_type: int, byte_alignment: int) -> cute.Pointer: ... + + @overload + def allocate( + self, size_or_type: cute.struct, byte_alignment: int + ) -> cute.Pointer: ... + + def allocate(self, size_or_type, byte_alignment: int = 1) -> cute.Pointer: + """Allocate a block of memory with specified size and alignment. 
+ + This method adjusts the base pointer to ensure proper alignment and updates + the internal state to track allocated memory. + + :param size_or_type: The number of bytes to allocate or a struct class + :type size_or_type: Union[int, cute.struct] + :param byte_alignment: The byte alignment requirement, defaults to 1 (no alignment) + :type byte_alignment: int, optional + :return: Pointer to the start of the allocated memory block or struct instance + :rtype: cute.Pointer + :raises ValueError: If size is negative or alignment is less than 1 + :raises RuntimeError: If allocation would exceed available shared memory + """ + if isinstance(size_or_type, cute.struct): + alignment = max(byte_alignment, size_or_type.__alignof__()) + base_ptr = self.allocate(size_or_type.__sizeof__(), alignment) + return size_or_type(base_ptr) + + num_bytes = size_or_type + if num_bytes < 0: + raise ValueError("num_bytes must be non-negative") + if byte_alignment < 1: + raise ValueError("byte_alignment must be at least 1") + + self._base = self._base.align(byte_alignment) + ptr = self._base + self._base += num_bytes + if self._allocated_bytes % byte_alignment != 0: + self._allocated_bytes += ( + byte_alignment - self._allocated_bytes % byte_alignment + ) + self._allocated_bytes += num_bytes + + # Check bounds against available dynamic shared memory + cute.testing.assert_( + self._allocated_bytes <= get_dyn_smem_size(), + f"Allocation failed: shared memory allocation exceeds available memory set in kernel launch. " + f"Allocated bytes: {self._allocated_bytes} bytes. " + f"Please reduce the allocation or set a larger smem size in kernel launch.", + ) + return ptr + + def allocate_array(self, element_type: Type[Numeric], num_elems: int = 1): + """Allocate an array of elements in shared memory. 
+ + :param element_type: The type of elements to allocate + :type element_type: Type[Numeric] + :param num_elems: Number of elements to allocate, defaults to 1 + :type num_elems: int, optional + :return: Pointer to the start of the allocated array + :rtype: cute.Pointer + :raises ValueError: If num_elems is less than 1 + :raises TypeError: If element_type is not a Numeric type + """ + if num_elems < 1: + raise ValueError("num_elems must be at least 1") + if not isinstance(element_type, NumericMeta): + raise TypeError( + f"value_ty must be a type of Numeric, but got {element_type}" + ) + + ptr = self.allocate( + element_type.width // 8 * num_elems, element_type.width // 8 + ) + + return cute.recast_ptr(ptr, dtype=element_type) + + def allocate_tensor( + self, + element_type: Type[Numeric], + layout: Union[int, cute.Layout, cute.ComposedLayout], + byte_alignment: int = 1, + swizzle: cute.Swizzle = None, + ): + """Allocate a tensor in shared memory. + + :param element_type: The type of elements in the tensor + :type element_type: Type[Numeric] + :param layout: The layout specification for the tensor + :type layout: Union[int, cute.Layout, cute.ComposedLayout] + :param byte_alignment: The byte alignment requirement, defaults to 1 + :type byte_alignment: int, optional + :param swizzle: Swizzle for position-dependent swizzling, defaults to None + :type swizzle: cute.Swizzle, optional + :return: The allocated tensor with specified properties + :rtype: cute.Tensor + :raises TypeError: If element_type is not a Numeric type or if swizzle conflicts with layout + :raises ValueError: If allocation is not byte-aligned + :raises NotImplementedError: If dynamic layout is specified + """ + if not isinstance(element_type, NumericMeta): + raise TypeError( + f"value_ty must be a type of Numeric, but got {element_type}" + ) + + if ( + isinstance(layout, cute.ComposedLayout) + and isinstance(layout.inner, cute.Swizzle) + ) and (swizzle is not None): + raise TypeError( + f"Invalid tensor 
type: cannot be both iterator swizzle (PDSL) and swizzle layout(PISL) at the same time." + ) + + if isinstance(layout, int): + layout = cute.make_layout(layout) + + profile = layout(0) + if isinstance(profile, tuple): + raise TypeError( + f"cannot allocate a shared memory tensor with a non-integer iterator" + ) + + if not cute.is_static(layout.type): + raise NotImplementedError(f"dynamic layout is not supported: {layout.type}") + + # At least align the allocation to the natural alignment given by the element type + if element_type.width // 8 > byte_alignment: + byte_alignment = element_type.width // 8 + + # Relevant only for sub-byte data types: verify that the entire allocation is byte-aligned + cosize_in_bits = cute.cosize(layout) * element_type.width + assert isinstance(cosize_in_bits, int) + if cosize_in_bits % 8 != 0: + raise ValueError("invalid allocation that is not byte-aligned") + + num_bytes = cosize_in_bits // 8 + ptr = self.allocate(num_bytes, byte_alignment) + ptr = cute.recast_ptr(ptr, swizzle, dtype=element_type) + res = cute.make_tensor(ptr, layout) + return res diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/smem_capacity.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/smem_capacity.py new file mode 100644 index 0000000000000000000000000000000000000000..87ddb990436caf8135a849b3a37bf52632eed2fc --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/smem_capacity.py @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + + +SMEM_CAPACITY_MAP = { + "sm_120": (100 - 1) * 1024, + "sm_100": (228 - 1) * 1024, + "sm_90": (228 - 1) * 1024, + "sm_80": (164 - 1) * 1024, + "sm_86": (100 - 1) * 1024, + "sm_89": (100 - 1) * 1024, +} + + +def get_smem_capacity_in_bytes(compute_capability: str) -> int: + if compute_capability not in SMEM_CAPACITY_MAP: + raise ValueError(f"Unsupported compute capability: {compute_capability}") + return SMEM_CAPACITY_MAP[compute_capability] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..2873244d7cce9d8072f1fa71bbba1762022631b9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py @@ -0,0 +1,386 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. 
+ +from typing import Tuple + +from cutlass.cutlass_dsl import ( + Boolean, + Integer, + Int32, + min, + extract_mlir_values, + new_from_mlir_values, + dsl_user_op, +) +from cutlass._mlir import ir +import cutlass.cute as cute + +############################################################################## +# Static persistent tile scheduler +############################################################################## + + +class WorkTileInfo: + """A class to represent information about a work tile. + + :ivar tile_idx: The index of the tile. + :type tile_idx: cute.Coord + :ivar is_valid_tile: Whether the tile is valid. + :type is_valid_tile: Boolean + """ + + def __init__(self, tile_idx: cute.Coord, is_valid_tile: Boolean): + self._tile_idx = tile_idx + self._is_valid_tile = Boolean(is_valid_tile) + + def __extract_mlir_values__(self) -> list[ir.Value]: + values = extract_mlir_values(self.tile_idx) + values.extend(extract_mlir_values(self.is_valid_tile)) + return values + + def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": + assert len(values) == 4 + new_tile_idx = new_from_mlir_values(self._tile_idx, values[:-1]) + new_is_valid_tile = new_from_mlir_values(self._is_valid_tile, [values[-1]]) + return WorkTileInfo(new_tile_idx, new_is_valid_tile) + + @property + def is_valid_tile(self) -> Boolean: + """Check latest tile returned by the scheduler is valid or not. Any scheduling + requests after all tasks completed will return an invalid tile. + + :return: The validity of the tile. + :rtype: Boolean + """ + return self._is_valid_tile + + @property + def tile_idx(self) -> cute.Coord: + """ + Get the index of the tile. + + :return: The index of the tile. + :rtype: cute.Coord + """ + return self._tile_idx + + +class PersistentTileSchedulerParams: + """A class to represent parameters for a persistent tile scheduler. + + This class is designed to manage and compute the layout of clusters and tiles + in a batched gemm problem. 
+ + :ivar cluster_shape_mn: Shape of the cluster in (m, n) dimensions (K dimension cta count must be 1). + :type cluster_shape_mn: tuple + :ivar problem_layout_ncluster_mnl: Layout of the problem in terms of + number of clusters in (m, n, l) dimensions. + :type problem_layout_ncluster_mnl: cute.Layout + """ + + def __init__( + self, + problem_shape_ntile_mnl: cute.Shape, + cluster_shape_mnk: cute.Shape, + *, + loc=None, + ip=None, + ): + """ + Initializes the PersistentTileSchedulerParams with the given parameters. + + :param problem_shape_ntile_mnl: The shape of the problem in terms of + number of CTA (Cooperative Thread Array) in (m, n, l) dimensions. + :type problem_shape_ntile_mnl: cute.Shape + :param cluster_shape_mnk: The shape of the cluster in (m, n) dimensions. + :type cluster_shape_mnk: cute.Shape + + :raises ValueError: If cluster_shape_k is not 1. + """ + + if cluster_shape_mnk[2] != 1: + raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}") + + self.problem_shape_ntile_mnl = problem_shape_ntile_mnl + # cluster_shape_mnk is kept for reconstruction + self._cluster_shape_mnk = cluster_shape_mnk + self.cluster_shape_mn = cluster_shape_mnk[:2] + self._loc = loc + + # By default, we follow m major (col-major) raster order, so make a col-major layout + self.problem_layout_ncluster_mnl = cute.make_layout( + cute.ceil_div( + self.problem_shape_ntile_mnl, cluster_shape_mnk[:2], loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.problem_shape_ntile_mnl, self._cluster_shape_mnk]: + obj_values = extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.problem_shape_ntile_mnl, self._cluster_shape_mnk], self._values_pos + ): + obj_list.append(new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + 
return PersistentTileSchedulerParams(*(tuple(obj_list)), loc=self._loc) + + @dsl_user_op + def get_grid_shape( + self, max_active_clusters: Int32, *, loc=None, ip=None + ) -> Tuple[Integer, Integer, Integer]: + """ + Computes the grid shape based on the maximum active clusters allowed. + + :param max_active_clusters: The maximum number of active clusters that + can run in one wave. + :type max_active_clusters: Int32 + + :return: A tuple containing the grid shape in (m, n, persistent_clusters). + - m: self.cluster_shape_m. + - n: self.cluster_shape_n. + - persistent_clusters: Number of persistent clusters that can run. + """ + + # Total ctas in problem size + num_ctas_mnl = tuple( + x * y + for x, y in zip( + self.problem_layout_ncluster_mnl.shape, self.cluster_shape_mn + ) + ) + (self.problem_layout_ncluster_mnl.shape[2],) + + num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip) + + num_ctas_per_cluster = cute.size(self.cluster_shape_mn, loc=loc, ip=ip) + # Total ctas that can run in one wave + num_ctas_per_wave = max_active_clusters * num_ctas_per_cluster + + num_persistent_ctas = min(num_ctas_in_problem, num_ctas_per_wave) + num_persistent_clusters = num_persistent_ctas // num_ctas_per_cluster + + return (*self.cluster_shape_mn, num_persistent_clusters) + + +class StaticPersistentTileScheduler: + """A scheduler for static persistent tile execution in CUTLASS/CuTe kernels. 
+ + :ivar params: Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl + :type params: PersistentTileSchedulerParams + :ivar num_persistent_clusters: Number of persistent clusters that can be launched + :type num_persistent_clusters: Int32 + :ivar cta_id_in_cluster: ID of the CTA within its cluster + :type cta_id_in_cluster: cute.Coord + :ivar _num_tiles_executed: Counter for executed tiles + :type _num_tiles_executed: Int32 + :ivar _current_work_linear_idx: Current cluster index + :type _current_work_linear_idx: Int32 + """ + + def __init__( + self, + params: PersistentTileSchedulerParams, + num_persistent_clusters: Int32, + current_work_linear_idx: Int32, + cta_id_in_cluster: cute.Coord, + num_tiles_executed: Int32, + ): + """ + Initializes the StaticPersistentTileScheduler with the given parameters. + + :param params: Tile schedule related params, including cluster shape and problem_layout_ncluster_mnl. + :type params: PersistentTileSchedulerParams + :param num_persistent_clusters: Number of persistent clusters that can be launched. + :type num_persistent_clusters: Int32 + :param current_work_linear_idx: Current cluster index. + :type current_work_linear_idx: Int32 + :param cta_id_in_cluster: ID of the CTA within its cluster. + :type cta_id_in_cluster: cute.Coord + :param num_tiles_executed: Counter for executed tiles. 
+ :type num_tiles_executed: Int32 + """ + self.params = params + self.num_persistent_clusters = num_persistent_clusters + self._current_work_linear_idx = current_work_linear_idx + self.cta_id_in_cluster = cta_id_in_cluster + self._num_tiles_executed = num_tiles_executed + + def __extract_mlir_values__(self) -> list[ir.Value]: + values = extract_mlir_values(self.num_persistent_clusters) + values.extend(extract_mlir_values(self._current_work_linear_idx)) + values.extend(extract_mlir_values(self.cta_id_in_cluster)) + values.extend(extract_mlir_values(self._num_tiles_executed)) + return values + + def __new_from_mlir_values__( + self, values: list[ir.Value] + ) -> "StaticPersistentTileScheduler": + assert len(values) == 6 + new_num_persistent_clusters = new_from_mlir_values( + self.num_persistent_clusters, [values[0]] + ) + new_current_work_linear_idx = new_from_mlir_values( + self._current_work_linear_idx, [values[1]] + ) + new_cta_id_in_cluster = new_from_mlir_values( + self.cta_id_in_cluster, values[2:5] + ) + new_num_tiles_executed = new_from_mlir_values( + self._num_tiles_executed, [values[5]] + ) + return StaticPersistentTileScheduler( + self.params, + new_num_persistent_clusters, + new_current_work_linear_idx, + new_cta_id_in_cluster, + new_num_tiles_executed, + ) + + # called by host + @dsl_user_op + @staticmethod + def create( + params: PersistentTileSchedulerParams, + block_idx: Tuple[Integer, Integer, Integer], + grid_dim: Tuple[Integer, Integer, Integer], + *, + loc=None, + ip=None, + ): + """Initialize the static persistent tile scheduler. + + :param params: Parameters for the persistent + tile scheduler. + :type params: PersistentTileSchedulerParams + :param block_idx: The 3d block index in the format (bidx, bidy, bidz). + :type block_idx: Tuple[Integer, Integer, Integer] + :param grid_dim: The 3d grid dimensions for kernel launch. + :type grid_dim: Tuple[Integer, Integer, Integer] + + :return: A StaticPersistentTileScheduler object. 
+ :rtype: StaticPersistentTileScheduler + """ + params = params + + # Calculate the number of persistent clusters by dividing the total grid size + # by the number of CTAs per cluster + num_persistent_clusters = cute.size(grid_dim, loc=loc, ip=ip) // cute.size( + params.cluster_shape_mn, loc=loc, ip=ip + ) + + bidx, bidy, bidz = block_idx + + # Initialize workload index equals to the cluster index in the grid + current_work_linear_idx = Int32(bidz) + + # CTA id in the cluster + cta_id_in_cluster = ( + Int32(bidx % params.cluster_shape_mn[0]), + Int32(bidy % params.cluster_shape_mn[1]), + Int32(0), + ) + # Initialize number of tiles executed to zero + num_tiles_executed = Int32(0) + return StaticPersistentTileScheduler( + params, + num_persistent_clusters, + current_work_linear_idx, + cta_id_in_cluster, + num_tiles_executed, + ) + + # called by host + @staticmethod + def get_grid_shape( + params: PersistentTileSchedulerParams, + max_active_clusters: Int32, + *, + loc=None, + ip=None, + ) -> Tuple[Integer, Integer, Integer]: + """Calculates the grid shape to be launched on GPU using problem shape, + threadblock shape, and active cluster size. + + :param params: Parameters for grid shape calculation. + :type params: PersistentTileSchedulerParams + :param max_active_clusters: Maximum active clusters allowed. + :type max_active_clusters: Int32 + + :return: The calculated 3d grid shape. + :rtype: Tuple[Integer, Integer, Integer] + """ + + return params.get_grid_shape(max_active_clusters, loc=loc, ip=ip) + + # private method + def _get_current_work_for_linear_idx( + self, current_work_linear_idx: Int32, *, loc=None, ip=None + ) -> WorkTileInfo: + """Compute current tile coord given current_work_linear_idx and cta_id_in_cluster. + + :param current_work_linear_idx: The linear index of the current work. + :type current_work_linear_idx: Int32 + + :return: An object containing information about the current tile coordinates + and validity status. 
+ :rtype: WorkTileInfo + """ + + is_valid = current_work_linear_idx < cute.size( + self.params.problem_layout_ncluster_mnl, loc=loc, ip=ip + ) + + cur_cluster_coord = self.params.problem_layout_ncluster_mnl.get_hier_coord( + current_work_linear_idx, loc=loc, ip=ip + ) + + # cur_tile_coord is a tuple of i32 values + cur_tile_coord = tuple( + Int32(x) * Int32(z) + Int32(y) + for x, y, z in zip( + cur_cluster_coord, + self.cta_id_in_cluster, + (*self.params.cluster_shape_mn, Int32(1)), + ) + ) + + return WorkTileInfo(cur_tile_coord, is_valid) + + @dsl_user_op + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + return self._get_current_work_for_linear_idx( + self._current_work_linear_idx, loc=loc, ip=ip + ) + + @dsl_user_op + def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo: + return self.get_current_work(loc=loc, ip=ip) + + @dsl_user_op + def advance_to_next_work(self, *, advance_count: int = 1, loc=None, ip=None): + self._current_work_linear_idx += Int32(advance_count) * Int32( + self.num_persistent_clusters + ) + self._num_tiles_executed += Int32(1) + + @property + def num_tiles_executed(self) -> Int32: + return self._num_tiles_executed + + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/tensormap_manager.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/tensormap_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..c6369c200e13ad280dfdecdb5cb4aa7ad081da4c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass/utils/tensormap_manager.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from dataclasses import dataclass +from enum import Enum, auto +from typing import Tuple + +from cutlass.cutlass_dsl import const_expr + +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir + +import cutlass.cute as cute + + +class TensorMapUpdateMode(Enum): + """ + Enum class defining tensor map update modes. + + Modes: + GMEM: Update tensormap in global memory + SMEM: Load tensormap from global memory to shared memory, + update it in shared memory, then store back to global memory + """ + + GMEM = auto() # Update tensormap in global memory + SMEM = auto() # Update tensormap in shared memory + + +@dataclass(frozen=True) +class TensorMapManager: + """ + Manages TensorMap operations including initialization and updates. + Provides utilities to convert tensormap pointer to across different memory spaces. + """ + + tensormap_update_mode: TensorMapUpdateMode + bytes_per_tensormap: int + + # convert given cute.Pointer or cutlass.Int64 to a cute.Pointer to tensormap. + # address_space: the address space of the resulting tensormap pointer. 
It could be generic or gmem + def get_tensormap_ptr( + self, + ptr: cute.Pointer, + address_space=_cute_ir.AddressSpace.gmem, + ) -> cute.Pointer: + if address_space not in [ + _cute_ir.AddressSpace.gmem, + _cute_ir.AddressSpace.generic, + ]: + raise ValueError(f"Invalid address space: {address_space} for tensormap") + + gmem_ptr_i64 = ptr.toint().ir_value() + gmem_ptr_i64_align_ty = _cute_ir.ConstrainedIntType.get( + self.bytes_per_tensormap, gmem_ptr_i64.type.width + ) + gmem_ptr_i64_align = _cute_ir.assume(gmem_ptr_i64_align_ty, gmem_ptr_i64) + gmem_ptr_ty = _cute_ir.PtrType.get( + _cute_nvgpu_ir.TmaDescriptorTiledType.get(), + address_space, + self.bytes_per_tensormap, + ) + return _cute_ir.inttoptr(gmem_ptr_ty, gmem_ptr_i64_align) + + # init tensormap pointed by dst_ptr with the one inside copy_atom. + # dst_ptr should be pointing to a global memory location or a smem location + # warp_id specifies which warp to perform the initialization + @cute.jit + def init_tensormap_from_atom( + self, copy_atom: cute.CopyAtom, dst_ptr: cute.Pointer, warp_id: int + ) -> None: + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + if warp_idx == warp_id: + with cute.arch.elect_one(): + cute.nvgpu.cpasync.copy_tensormap(copy_atom, dst_ptr) + cute.arch.sync_warp() + return + + # Perform a fence operation to ensure previous `init_tensormap_from_atom` calls have been completed + def fence_tensormap_initialization( + self, + ) -> None: + if self.tensormap_update_mode == TensorMapUpdateMode.GMEM: + cute.arch.fence_acq_rel_cta() + return + + # Perform a fence operation to ensure previous `update_tensormap` calls have been completed + def fence_tensormap_update( + self, + tensormap_ptr: cute.Pointer, + ) -> None: + cute.nvgpu.cpasync.fence_tma_desc_acquire(tensormap_ptr) + return + + @cute.jit + def update_tensormap( + self, + tensor_gmem: Tuple[cute.Tensor, ...], + tma_copy_atom: Tuple[cute.CopyAtom, ...], + tensormap_gmem_ptr: Tuple[cute.Pointer, 
...], + warp_id: int, + tensormap_smem_ptr: Tuple[cute.Pointer, ...], + ) -> None: + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + # updates before touching tensormap in global memory + if warp_idx == warp_id: + if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): + for copy_atom, tensor, smem_ptr in zip( + tma_copy_atom, tensor_gmem, tensormap_smem_ptr + ): + cute.nvgpu.cpasync.update_tma_descriptor( + copy_atom, tensor, smem_ptr + ) + # wait until it's safe to update tensormap in global memory + with cute.arch.elect_one(): + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + cute.arch.sync_warp() + # updates to tensormap in global memory + if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): + for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr): + cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr) + else: + for copy_atom, tensor, gmem_ptr in zip( + tma_copy_atom, tensor_gmem, tensormap_gmem_ptr + ): + cute.nvgpu.cpasync.update_tma_descriptor( + copy_atom, tensor, gmem_ptr + ) + cute.arch.sync_warp() + cute.nvgpu.cpasync.fence_tma_desc_release() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..06ea3f6f5f54b0b4f125c22504b06f41e8bf7697 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/__init__.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from .cutlass import * + +from ..base_dsl.ast_helpers import ( + loop_selector, + if_selector, + if_executor, + while_selector, + while_executor, + range, + range_constexpr, + range_dynamic, + const_expr, + dynamic_expr, + assert_executor, + bool_cast, + compare_executor, + any_executor, + all_executor, + range_value_check, + range_perf_warning, + cf_symbol_check, + redirect_builtin_function, + copy_members, + get_locals_or_none, +) + +from ..base_dsl import * +from ..base_dsl.dsl import extract_mlir_values, new_from_mlir_values +from ..base_dsl.typing import _binary_op_type_promote +from ..base_dsl._mlir_helpers.gpu import * +from ..base_dsl._mlir_helpers.op import dsl_user_op +from ..base_dsl.runtime import * +from ..base_dsl.runtime import cuda as cuda_helpers +from ..base_dsl.compiler import compile +from ..base_dsl.runtime.jit_arg_adapters import * diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/cutlass.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..1630c873c7a1be3e013f966ea153c904f2b776ff --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/cutlass.py @@ -0,0 +1,1696 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +""" +This module provides a DSL for Cutlass Dialects. It also includes utils with +regarding to that dialect. +""" + +# Local module imports +from itertools import chain +from types import GenericAlias, SimpleNamespace, UnionType +from typing import Callable, Union, Type, List, Union, Sequence, ForwardRef, Any +import functools +import pkgutil +from dataclasses import is_dataclass, fields +from collections.abc import Sequence +import builtins + +from ..base_dsl import * +from ..base_dsl import compiler +from ..base_dsl.dsl import is_dynamic_expression, extract_mlir_values +from ..base_dsl.typing import * +from ..base_dsl.typing import DynamicExpression, get_mlir_types +from ..base_dsl.runtime.jit_arg_adapters import is_arg_spec_constexpr + +from ..base_dsl.ast_helpers import const_expr + +# MLIR Imports +from cutlass._mlir import ir, execution_engine, passmanager +from cutlass._mlir.dialects import arith, func, gpu, scf, cute, gpu as cutlass_gpu +from cutlass._mlir.dialects._ods_common import ( + get_op_result_or_op_results as _get_op_result_or_op_results, +) +from cutlass._mlir.extras import types as T + +# Helpers +from ..base_dsl._mlir_helpers import arith as cutlass_arith +from ..base_dsl._mlir_helpers import lru_cache_ir + +from ..base_dsl.ast_helpers import ( + loop_selector, + executor, + if_selector, + if_executor, + while_selector, + while_executor, + assert_executor, + const_expr, + dynamic_expr, + bool_cast, + compare_executor, + any_executor, + all_executor, + range_value_check, + range_perf_warning, + cf_symbol_check, +) + +from 
.cutlass_ast_decorators import ( + _loop_execute_range_dynamic, + _if_execute_dynamic, + _while_execute_dynamic, +) + +from .tree_utils import ( + is_constexpr_field, + tree_flatten, + tree_unflatten, + PyTreeDef, + is_frozen_dataclass, + DSLTreeFlattenError, +) +from ..base_dsl.runtime.jit_arg_adapters import JitArgAdapterRegistry + + +# ============================================================================= +# Cutlass DSL Base Abstract Class +# ============================================================================= + + +# Return a ctype class that represents the in-memory layout expected +# for a CuTe hierarchical tuple type. +def get_sparse_tuple_ctype(dyn): + # When there is a single dynamic value, the sparse CuTe + # representation is a single integer. + if isinstance(dyn, int): + return ctypes.c_int32 + + # For zero or greater than 1 dynamic values, the tuple + # representation will be a struct with a field for each dynamic + # value. The representation is flattened, even for hierarchical CuTe + # profiles (although we are only dealing with depth 1 inputs here). 
+ class TupleDescriptor(ctypes.Structure): + _fields_ = [(f"x{idx}", ctypes.c_int32) for idx in range(len(dyn))] + + def __str__(self): + return f"struct<{str(self._fields_)}>" + + return TupleDescriptor + + +def is_cute_algebra_type(arg_spec): + # Walk through the arg_spec to check if it's a cute algebra type + _cute_algebra_type_aliases = ( + "Shape", + "Stride", + "Coord", + "Tile", + "IntTuple", + ) + + origin = get_origin(arg_spec) + if origin is Union: + for sub_ty in get_args(arg_spec): + sub_origin = get_origin(sub_ty) + if sub_origin is Tuple or ( + type(sub_origin) is type and issubclass(sub_origin, tuple) + ): + tuple_arg0 = get_args(sub_ty)[0] + if isinstance( + tuple_arg0, ForwardRef + ) and tuple_arg0.__forward_arg__ in (_cute_algebra_type_aliases): + return True + return False + + +def _get_c_pointers_cutlass(obj): + """ + This is an extended version of `get_c_pointers` that supports dataclasses, SimpleNamespace, and dict. + """ + if hasattr(obj, "__c_pointers__"): + return obj.__c_pointers__() + elif isinstance(obj, (tuple, list)): + return list(chain.from_iterable(_get_c_pointers_cutlass(x) for x in obj)) + elif isinstance(obj, SimpleNamespace): + return list( + chain.from_iterable( + _get_c_pointers_cutlass(x) for x in obj.__dict__.values() + ) + ) + elif isinstance(obj, dict): + return list( + chain.from_iterable(_get_c_pointers_cutlass(x) for x in obj.values()) + ) + elif is_dataclass(obj): + return list( + chain.from_iterable( + _get_c_pointers_cutlass(getattr(obj, f.name)) + for f in fields(obj) + if not is_constexpr_field(f) + ) + ) + elif isinstance(obj, set): + raise DSLRuntimeError( + "Sets are not supported in get_c_pointers to ensure order preservation", + context="The DSL attempted to generate JIT function argument(s) for an argument of type set but failed.", + suggestion="Consider using a list or tuple instead", + ) + else: + # Try get adapter + adapter = JitArgAdapterRegistry.get_registered_adapter(type(obj)) + if adapter is not None: 
+ return _get_c_pointers_cutlass(adapter(obj)) + return [] + + +class CutlassBaseDSL(BaseDSL): + """This abstract class provides a DSL for Cutlass.""" + + def __init__( + self, + name: str, + compiler_provider: Any, + pass_sm_arch_name: str, + device_compilation_only: bool = False, + preprocess: bool = False, + ): + super().__init__( + name=name, + dsl_package_name=["cutlass"], + compiler_provider=compiler_provider, + pass_sm_arch_name=pass_sm_arch_name, + device_compilation_only=device_compilation_only, + preprocess=preprocess, + ) + self._smem_usage_tracker: tuple = None + + # this method is not useful for cutlass_dsl, so we only provide a dummy implementation. + def _is_tensor_descriptor(self, maybe_tensor_descriptor) -> bool: + return False + + # this method is not useful for cutlass_dsl, so we only provide a dummy implementation. + def _handle_tensor_descriptor( + self, maybe_tensor, arg_name: str, need_gpu_memory: bool + ) -> Any: + return False + + def _build_gpu_module(self, attrs): + self.gpu_module = gpu.GPUModuleOp(ir.StringAttr.get("kernels")) + with ir.InsertionPoint(self.gpu_module.bodyRegion.blocks.append(*[])): + pass + + for attr_name in attrs: + self.gpu_module.attributes[attr_name] = ir.Attribute.parse(attrs[attr_name]) + + def _get_pipeline(self, pipeline): + pipeline = super()._get_pipeline(pipeline) + if pipeline == None: + # cubin format is required to be cubin as we launch cuda module at python level. 
+ return ( + "builtin.module(cute-to-nvvm{cubin-format=bin " + + self.compile_options.to_str() + + "})" + ) + + return pipeline + + def preprocess_pipeline(self, pipeline, arch) -> str: + pipeline = super().preprocess_pipeline(pipeline, arch) + pipeline = pipeline.rstrip(")") + ",external-kernel-for-gpu-launch)" + return pipeline + + def _enter_gpu_module(self): + return ir.InsertionPoint(self.gpu_module.bodyRegion.blocks[0]) + + def _generate_kernel_attrs(self, config: BaseDSL.LaunchConfig) -> dict: + assert isinstance( + config, BaseDSL.LaunchConfig + ), f"Expect LaunchConfig for @kernel, but got {type(config)}" + + ret = {} + # generate launch bound attr from LaunchConfig + max_threads = ", ".join(map(str, config.block)) + ret["nvvm.reqntid"] = ir.Attribute.parse(f"array") + # min_blocks_per_mp is optional for kernel + min_blocks = config.min_blocks_per_mp + if min_blocks > 0: + ret["nvvm.minctasm"] = ir.Attribute.parse(f"{min_blocks} : i32") + return ret + + @lru_cache(maxsize=1) + def get_version(self): + """ + Get the version of cutlass dsl, used for computing the hash key of the cache. + Including source python files and the shared library. + """ + dsl_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + # get the version hash of the cutlass shared library + version_hash = hashlib.sha256() + # update the version hash of the source python files + for lib in pkgutil.walk_packages([dsl_path], prefix="cutlass."): + try: + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + version_hash.update(f.read()) + except Exception: + raise DSLRuntimeError( + f"Failed to read module file {lib.name}. The file may not exist or may not be readable." + "Please re-install the package." 
+ ) + try: + # update the version hash of the cutlass shared library + with open( + os.path.join(dsl_path, "_mlir/_mlir_libs/libCutlassIRPythonCAPI.so"), + "rb", + ) as f: + while True: + chunk = f.read(1024**2) + if not chunk: + break + version_hash.update(chunk) + except Exception: + raise DSLRuntimeError( + f"Failed to read the shared library file libCutlassIRPythonCAPI.so." + "The file may not exist or may not be readable." + "Please re-install the package." + ) + + return version_hash + + @staticmethod + def track_smem_allocator(allocator, callback): + """ + Tracks shared memory usage for kernel functions. + Find and set allocator to its parent dsl object. + """ + frame = inspect.currentframe().f_back + while frame: + obj = frame.f_locals.get("self", None) + if obj and isinstance(obj, CutlassBaseDSL): + obj._set_smem_tracking(allocator, callback) + return + frame = frame.f_back + warnings.warn("Cannot find parent dsl for allocator!", UserWarning) + + def _set_smem_tracking(self, allocator, callback): + # Registers an allocator and callback for current dsl + self._smem_usage_tracker = (allocator, callback) + + def _reset_smem_tracking(self): + # Clear an allocator and callback for current dsl + self._smem_usage_tracker = None + + def _get_smem_usage(self) -> int: + # Treat final allocated bytes of allocator as smem usage + if not self._smem_usage_tracker: + return 0 + allocator, callback = self._smem_usage_tracker + return callback(allocator) + + def _kernel_helper(self, funcBody, *args, **kwargs): + class _CutlassIrKernelGenHelper(BaseDSL._KernelGenHelper): + def __init__(self, dsl: CutlassBaseDSL): + super().__init__() + self.dsl = dsl + self.dsl._reset_smem_tracking() + + def generate_func_op(self, arg_types, arg_attrs, kernel_name, loc=None): + super().generate_func_op(arg_types, arg_attrs, kernel_name) + self.func_op = func.FuncOp( + kernel_name, ir.FunctionType.get(arg_types, []), loc=loc + ) + if arg_attrs is not None: + log().debug(arg_attrs) + 
self.func_op.arg_attrs = arg_attrs + return self.func_op + + def generate_func_ret_op(self): + return func.ReturnOp([]) + + def get_func_body_start(self): + assert self.func_op is not None, "Invalid func_op is not expected!" + return self.func_op.add_entry_block() + + def generate_launch_op(self, *args, **kwargs): + # Extract args and do validation + kernelSym = kwargs.get("kernelSym", None) + kernelOperands = kwargs.get("kernelOperands", None) + requiredArgs = kwargs.get("requiredArgs", None) + assert kernelSym is not None, "kernelSym being None is not expected!" + assert ( + requiredArgs is not None + ), "requiredArgs being None is not expected!" + assert ( + kernelOperands is not None + ), "kernelOperands being None is not expected!" + assert isinstance( + requiredArgs.config, BaseDSL.LaunchConfig + ), f"Expect LaunchConfig for @kernel, but got {type(requiredArgs.config)}" + + cfg = requiredArgs.config + + # Apply to grid, block, and cluster if present + cfg.grid = [to_index(size) for size in cfg.grid] + cfg.block = [to_index(size) for size in cfg.block] + if cfg.has_cluster: + cfg.cluster = [to_index(size) for size in cfg.cluster] + + smem_usage = self.dsl._get_smem_usage() + if any(not isinstance(x, int) for x in [cfg.smem, smem_usage]): + pass # cannot compare dynamic value inside kernel to launch op in py + elif cfg.auto_smem: + cfg.smem = smem_usage + elif smem_usage > cfg.smem: + warnings.warn( + f"Potential error: specified kernel launch smem bytes " + f"({cfg.smem}) is smaller than kernel usage ({smem_usage})!", + UserWarning, + ) + cfg.smem = const(cfg.smem) + + if not isinstance(cfg.async_deps, (list, tuple)): + cfg.async_deps = [cfg.async_deps] + is_async = len(cfg.async_deps) > 0 + token = gpu.launch_func( + gpu.AsyncTokenType.get() if is_async else None, + cfg.async_deps, + kernelSym, + *cfg.grid, + *cfg.block, + kernelOperands, + **dict( + zip( + ("cluster_size_x", "cluster_size_y", "cluster_size_z"), + tuple(cfg.cluster), + ) + ), + 
dynamic_shared_memory_size=cfg.smem, + ) + return token if is_async else None + + return KernelLauncher( + self, + lambda: _CutlassIrKernelGenHelper(self), + funcBody, + *args, + **kwargs, + ) + + def _preprocess_launch_config_args(self, args, kwargs): + """Helper to preprocess args and kwargs for LaunchConfig""" + if "stream" in kwargs: + kwargs["async_deps"] = kwargs.pop("stream") + + def mangle_name(self, function_name, args, args_spec: inspect.FullArgSpec): + """Mangle the name of the function to avoid conflicts with other functions""" + function_name = "cutlass_" + function_name + return super().mangle_name(function_name, args, args_spec) + + def _validate_arg(self, arg, arg_index, arg_name, arg_annotation): + """ + Validates if the arg is really of the annotated type. + """ + + if ( + is_arg_spec_constexpr(arg_annotation, arg_name, arg_index, None) + or arg_annotation is Any + ): + pass + else: + origin = get_origin(arg_annotation) + # Handle special case where annotation is Type[X] but arg is an actual type + if origin is type and isinstance(arg, type): + # Get the expected base type from Type[X] + expected_base = get_args(arg_annotation)[0] + if not issubclass(arg, expected_base): + return DSLRuntimeError( + f"expects argument #{arg_index+1} ({arg_name}) to be Type[{expected_base}], but got {arg}" + ) + # Handle Union types and generic types + elif origin is Union or isinstance(arg_annotation, UnionType): + # For Union types, check if arg matches any of the allowed types + allowed_types = get_args(arg_annotation) + if not any( + (ty is Any) + or (isinstance(ty, type) and isinstance(arg, ty)) + or (get_origin(ty) is tuple and isinstance(arg, tuple)) + for ty in allowed_types + ): + return DSLRuntimeError( + f"expects argument #{arg_index+1} ({arg_name}) to be one of {allowed_types}, but got {type(arg)}" + ) + elif isinstance(arg_annotation, type): + # Handle simple type annotations + if not isinstance(arg, arg_annotation) and arg is not None: + return 
DSLRuntimeError( + f"expects argument #{arg_index+1} ({arg_name}) to be {arg_annotation}, but got {type(arg)}" + ) + # Everything looks good if we are here + return None + + def _generate_jit_func_args_for_known_types( + self, + func, + arg, + arg_name, + arg_spec, + arg_index, + *, + is_host=True, + ): + jit_arg_type, jit_arg_attr, jit_exec_arg = [], [], [] + default_attr = ir.DictAttr.get({}) + + ( + jit_exec_arg, + jit_arg_type, + jit_arg_attr, + ) = super()._generate_jit_func_args_for_known_types( + func, arg, arg_name, arg_spec, arg_index, is_host=is_host + ) + + if jit_arg_type is not None and len(jit_arg_type) == 0: + # Handle DSL specific types + if is_cute_algebra_type(arg_spec): + dyn_vals = extract_mlir_values(arg) + if dyn_vals: + # Handle dynamic types + jit_arg_type.extend([v.type for v in dyn_vals]) + jit_arg_attr.extend([default_attr] * len(dyn_vals)) + jit_exec_arg.extend(get_c_pointers(arg) if is_host else dyn_vals) + else: + jit_exec_arg = jit_arg_type = jit_arg_attr = None + elif not hasattr(arg, "__extract_mlir_values__") and not hasattr( + arg, "__new_from_mlir_values__" + ): + # Try tree_flatten + try: + dyn_vals, _ = tree_flatten(arg) + except DSLTreeFlattenError: + # If fails, just return the original arg + return jit_exec_arg, jit_arg_type, jit_arg_attr + + if dyn_vals: + jit_arg_type.extend([v.type for v in dyn_vals]) + jit_arg_attr.extend([default_attr] * len(dyn_vals)) + jit_exec_arg.extend( + _get_c_pointers_cutlass(arg) if is_host else dyn_vals + ) + else: + # If tree flatten yields empty list, treat it as a constexpr thing + # Like a dataclass with all fields are constexpr, or an empty tuple or list + jit_exec_arg = jit_arg_type = jit_arg_attr = None + return jit_exec_arg, jit_arg_type, jit_arg_attr + + def _generate_execution_arguments_for_known_types( + self, arg, arg_spec, arg_name, i, fop_args, iv_block_args + ): + ir_arg, iv_block_args = super()._generate_execution_arguments_for_known_types( + arg, arg_spec, arg_name, i, 
fop_args, iv_block_args + ) + if not ir_arg: + # Handling DSL specific types + if is_cute_algebra_type(arg_spec): + n_args = len(get_mlir_types(arg)) + blk_args = fop_args[iv_block_args : iv_block_args + n_args] + ir_arg.append(new_from_mlir_values(arg, blk_args)) + iv_block_args += n_args + elif not hasattr(arg, "__extract_mlir_values__") and not hasattr( + arg, "__new_from_mlir_values__" + ): + # Try tree_unflatten + try: + dyn_vals, tree_def = tree_flatten(arg) + block_args = fop_args[iv_block_args : iv_block_args + len(dyn_vals)] + ir_arg.append(tree_unflatten(tree_def, block_args)) + iv_block_args += len(dyn_vals) + except DSLTreeFlattenError: + return ir_arg, iv_block_args + + return ir_arg, iv_block_args + + +# ============================================================================= +# Cute DSL Class +# ============================================================================= + + +class CuTeDSL(CutlassBaseDSL): + """ + This is a concrete DSL subclass for the CuTe dialect. + """ + + def __init__(self): + name = "CUTE_DSL" + compiler_provider = compiler.Compiler(passmanager, execution_engine) + pass_sm_arch_name = "cubin-chip" + + super().__init__(name, compiler_provider, pass_sm_arch_name, preprocess=True) + + +# ============================================================================= +# KernelLauncher +# ============================================================================= + + +class KernelLauncher: + """ + This class is used to launch a kernel function. + Usage: + ```python + @cute.kernel + def kernel(arg1, arg2, ...): + ... + + @cute.jit + def launch_kernel(): + kernel(arg1, arg2, ...).launch(grid=[1, 1, 1], block=[1, 1, 1], ...) + # or + kernel(arg1, arg2, ...)(grid=[1, 1, 1], block=[1, 1, 1], ...) 
+ ``` + """ + + def __init__( + self, + dsl: "CutlassBaseDSL", + kernelGenHelper: BaseDSL._KernelGenHelper, + funcBody, + *func_args, + **func_kwargs, + ): + self.dsl = dsl + self.kernelGenHelper = kernelGenHelper + self.funcBody = funcBody + self.func_args = func_args + self.func_kwargs = func_kwargs + + self._check_func_args(funcBody, *func_args, **func_kwargs) + + def _check_func_args(self, funcBody, *func_args, **func_kwargs): + # Get function signature + sig = inspect.signature(funcBody) + + # func_args and func_kwargs should match funcBody's signature, + # no extra or missing arguments. + try: + sig.bind(*func_args, **func_kwargs) + except TypeError as e: + raise DSLRuntimeError( + f"Failed to bind arguments to function `{funcBody.__name__}` with signature `{sig}`", + cause=e, + ) + + def smem_usage(self) -> int: + """ + Check smem usage for this kernel, only available after `launch` + """ + return self.dsl._get_smem_usage() + + def launch(self, *args, **kwargs): + self.dsl.frame = inspect.currentframe().f_back + self.dsl._preprocess_launch_config_args(args, kwargs) + config = self.dsl.LaunchConfig(*args, **kwargs) + + kernel_generator = self.dsl.kernel_launcher( + requiredArgs=["config"], + unitAttrNames=["gpu.kernel", "cute.kernel"], + valueAttrDict=self.dsl._generate_kernel_attrs(config), + kernelGenHelper=self.kernelGenHelper, + )(self.funcBody) + + ret, name = kernel_generator(*self.func_args, **self.func_kwargs, config=config) + self.dsl.kernel_symbols.append(name) + self.dsl.frame = None + return ret.launch_op_ret + + def __call__(self, *args, **kwargs): + return self.launch(*args, **kwargs) + + +# ============================================================================= +# Utils +# ============================================================================= +def _filter_readonly_frozen_dataclass( + iter_args: List[Any], items_to_filter: List[Any], full_write_args_count: int +) -> List[Any]: + """ + Filter items based on whether corresponding 
iter_args are frozen dataclasses. + + This function filters items (which can be values or names) based on the same + logic: keep items if they correspond to full-write arguments (index < full_write_args_count) + or if the corresponding iter_arg is not a frozen dataclass. + + Args: + iter_args: List of arguments to check for frozen dataclass status + items_to_filter: List of items to filter (values or names) + full_write_args_count: Number of arguments that are always written (not read-only) + + Returns: + Filtered list of items + + Examples: + # Filter values (original remove_read_only_frozen_dataclass behavior) + filtered_values = _filter_readonly_frozen_dataclass(iter_args, iter_args, full_write_args_count) + + # Filter names (original filter_readonly_frozen_dataclass_names behavior) + filtered_names = _filter_readonly_frozen_dataclass(iter_args, iter_args_names, full_write_args_count) + """ + return [ + item + for i, item in enumerate(items_to_filter) + if i < full_write_args_count or not is_frozen_dataclass(iter_args[i]) + ] + + +def remove_read_only_frozen_dataclass( + iter_args: List[Any], full_write_args_count: int +) -> List[Any]: + """Filter out frozen dataclass arguments that are not full-write arguments.""" + return _filter_readonly_frozen_dataclass( + iter_args, iter_args, full_write_args_count + ) + + +def filter_readonly_frozen_dataclass_names( + iter_args: List[Any], iter_args_names: List[str], full_write_args_count: int +) -> List[str]: + """Filter names based on whether corresponding iter_args are frozen dataclasses.""" + return _filter_readonly_frozen_dataclass( + iter_args, iter_args_names, full_write_args_count + ) + + +def insert_read_only_frozen_dataclass( + iter_args: List[Any], original_iter_args: List[Any], full_write_args_count: int +) -> List[Any]: + """ + Insert read-only frozen dataclass arguments back into the iteration arguments. 
+ + This function takes the new iteration arguments and the original arguments, + and preserves frozen dataclass instances from the original arguments while + using the new arguments for non-frozen dataclass instances. + + Args: + iter_args: New iteration arguments to use for non-frozen dataclass instances + original_iter_args: Original iteration arguments to preserve frozen dataclass instances from + full_write_args_count: Number of arguments that are always written (not read-only) + + Returns: + List of arguments with frozen dataclass instances preserved from original + """ + # Take full-write arguments from new iter_args + full_write_args = ( + iter_args[:full_write_args_count] if full_write_args_count > 0 else [] + ) + + # Process remaining arguments: preserve frozen dataclass from original, use new for others + remaining_original = original_iter_args[full_write_args_count:] + remaining_new = iter_args[full_write_args_count:] + + def process_remaining_arg(original_arg, new_arg_iter): + """Process a single remaining argument, preserving frozen dataclass if present""" + return original_arg if is_frozen_dataclass(original_arg) else next(new_arg_iter) + + # Use zip to pair original args with new args, then map the processing function + new_arg_iter = iter(remaining_new) + processed_remaining = [ + process_remaining_arg(orig_arg, new_arg_iter) for orig_arg in remaining_original + ] + + return full_write_args + processed_remaining + + +def unpack_to_irvalue( + mixed_values: List[Any], body_name: str, full_write_args_count: int +) -> Tuple[List[ir.Value], PyTreeDef]: + log().debug("===--- Values UNPack") + for idx, packed in enumerate(mixed_values): + log().debug("[%d]: will-unpacked: [type:%s] %s", idx, type(packed), packed) + + try: + unpacked_values, treedef = tree_flatten( + remove_read_only_frozen_dataclass(mixed_values, full_write_args_count) + ) + except DSLTreeFlattenError as e: + raise DSLRuntimeError( + f"The '{body_name}' statement encountered a 
user-defined Python object, which cannot be automatically converted into an dynamic expression.", + context={ + e.message: ( + f"All expressions within '{body_name}' must be dynamic expressions, " + "mixing Python objects and dynamic expressions is not supported. " + "The DSL failed to convert the Python object into dynamic expressions." + ) + }, + suggestion=( + f"Please ensure '{e.type_str}' implements the '{DynamicExpression.__name__}' or mark with `dataclass`, " + f"so it can be treated as a valid dynamic expression or mark '{body_name}' as a constant expression if conditions are Python objects." + ), + ) + + log().debug("------------------ ") + for idx, unpacked in enumerate(unpacked_values): + log().debug("[%d]: unpacked values: %s", idx, unpacked) + log().debug("treedef: %s", treedef) + log().debug("------------------ ") + + return unpacked_values, treedef + + +def pack_from_irvalue( + ir_values: List["ir.Value"], + pytree_def: PyTreeDef, + mixed_values: List[Any], + full_write_args_count: int, +) -> List[Any]: + """ + Packs MLIR values into a list of mixed values. 
+ """ + log().debug("===--- Values Pack (%d)", len(ir_values)) + for idx, value in enumerate(ir_values): + log().debug("[%d]: will-packed: %s", idx, value) + log().debug("treedef: %s", pytree_def) + log().debug("------------------ ") + + unflattened = tree_unflatten(pytree_def, ir_values) + return insert_read_only_frozen_dataclass( + unflattened, mixed_values, full_write_args_count + ) + + +def to_index(value): + """Converts a value to an index, either by casting or coercing to int.""" + if is_dynamic_expression(value): + if isinstance(value, Numeric): + value = value.ir_value() + assert ir.IntegerType.isinstance( + value.type + ), f"expects integer type, but got {value.type}" + res = arith.index_cast(T.index(), value) + else: + res = const(int(value), ty=T.index()) + + return res + + +def _validate_iter_args_structure(iter_args, ir_values): + """ + Validates that iter_args structure contains the same number of atomic values + as there are IR values. + + Args: + iter_args: Original iteration arguments, possibly nested sequences + ir_values: Flattened MLIR values extracted from iter_args + + Returns: + bool: True if the number of atomic values in iter_args matches + the number of values in ir_values + """ + # Handle non-sequence case + if not isinstance(iter_args, (tuple, list, set)): + return not isinstance(ir_values, (tuple, list, set)) or len(ir_values) == 1 + + # If we have a sequence but ir_values isn't one, there's a mismatch + if not isinstance(ir_values, (tuple, list, set)): + return False + + # Count all non-sequence values recursively + def count_values(args): + if not isinstance(args, (tuple, list, set)): + return 1 + else: + return sum(count_values(arg) for arg in args) + + return count_values(iter_args) == len(ir_values) + + + +# ============================================================================= +# DSL implementation of Python Build-in Operators +# ============================================================================= + + +def 
_minmax(op, *args, loc=None, ip=None): + """Computes the minimum or maximum value from the provided arguments.""" + from ..base_dsl.typing import _binary_op, _binary_op_type_promote + + # AST Traversal doesn't support early exit in if executor + x = None + res = None + if len(args) == 1: + # Handle case for min([a, b, c, d, ..]) + if hasattr(args[0], "__iter__"): + x = op(*tuple(args[0])) + # Handle case for min(a) + else: + x = args[0] + # Handle case for min(a, b, c, ...) and min([x, y], [b]) and min(a, (x, y, z)) + elif len(args) > 1: + res, *xs = tuple(args) + for x in xs: + lhs = as_numeric(op(res, loc=loc, ip=ip)) + rhs = as_numeric(op(x, loc=loc, ip=ip)) + emitter = getattr(cutlass_arith, f"_{op.__name__}") + + lhs, rhs, res_type = _binary_op_type_promote(lhs, rhs, promote_bool=True) + + if isinstance(lhs.value, cutlass_arith.ArithValue) and isinstance( + lhs, Integer + ): + lhs_val = lhs.value.with_signedness(lhs.signed) + else: + lhs_val = lhs.value + + if isinstance(rhs.value, cutlass_arith.ArithValue) and isinstance( + rhs, Integer + ): + rhs_val = rhs.value.with_signedness(rhs.signed) + else: + rhs_val = rhs.value + + res = res_type(emitter(lhs_val, rhs_val), loc=loc, ip=ip) + x = res + else: + raise DSLNotImplemented(f"{type(args)} is not supported") + return x + + +def min(*args, loc=None, ip=None): + """Computes the minimum value from the provided arguments. + + This function differs from Python's built-in min() in that the return type + is determined by the static types of the inputs, not their dynamic values. 
+ + :param args: One or more values or iterables to find the minimum of + :type args: tuple + :param loc: Source location for MLIR operation tracking + :type loc: object, optional + :param ip: Insertion point for MLIR operation + :type ip: object, optional + :return: The minimum value among all inputs + :rtype: Numeric + :raises DSLNotImplemented: If the input type is not supported + + Supports multiple calling patterns: + + - min(a): Returns a + - min([a, b, c, ...]): Returns minimum of all elements in the iterable + - min(a, b, c, ...): Returns minimum of all arguments + - min([x, y], [b]): Returns minimum across all elements in all iterables + - min(a, (x, y, z)): Returns minimum across all elements + + Examples: + + .. code-block:: python + + # Find minimum of two values + result = min(x, y) + + # Find minimum of multiple values + result = min(a, b, c, d) + + # Find minimum of values in a list + values = [a, b, c, d] + result = min(values) + + # Find minimum across mixed arguments + result = min(x, [y, z]) + + Difference from Python's built-in min(): + + .. code-block:: python + + # In Python, the return type depends on the dynamic values: + a = 5 + b = 3.14 + result = min(a, b) # Returns 3.14 (float) + + # In this DSL implementation, the return type is determined statically: + a = Int32(5) + b = Float32(3.14) + result = min(a, b) # Return type is determined by the type of operands, not values + """ + return _minmax(min, *args, loc=loc, ip=ip) + + +def max(*args, loc=None, ip=None): + """Computes the maximum value from the provided arguments. + + This function differs from Python's built-in max() in that the return type + is determined by the static types of the inputs, not their dynamic values. 
+ + :param args: One or more values or iterables to find the maximum of + :type args: tuple + :param loc: Source location for MLIR operation tracking + :type loc: object, optional + :param ip: Insertion point for MLIR operation + :type ip: object, optional + :return: The maximum value among all inputs + :rtype: Numeric + :raises DSLNotImplemented: If the input type is not supported + + Supports multiple calling patterns: + + - max(a): Returns a + - max([a, b, c, ...]): Returns maximum of all elements in the iterable + - max(a, b, c, ...): Returns maximum of all arguments + - max([x, y], [b]): Returns maximum across all elements in all iterables + - max(a, (x, y, z)): Returns maximum across all elements + + Examples: + + .. code-block:: python + + # Find maximum of two values + result = max(x, y) + + # Find maximum of multiple values + result = max(a, b, c, d) + + # Find maximum of values in a list + values = [a, b, c, d] + result = max(values) + + # Find maximum across mixed arguments + result = max(x, [y, z]) + + Difference from Python's built-in max(): + + .. code-block:: python + + # In Python, the return type depends on the dynamic values: + a = 5 + b = 3.14 + result = max(a, b) # Returns 5 (int) + + # In this DSL implementation, the return type is determined statically: + a = Int32(5) + b = Float32(3.14) + result = max(a, b) # Return type is determined by the type of operands, not values + """ + return _minmax(max, *args, loc=loc, ip=ip) + + +def and_(*args, loc=None, ip=None): + """AND operation for value in DSL numeric types. 
+ + :param *args: One or more numeric values to AND together + :type *args: Numeric + :param loc: Source location for MLIR operation tracking + :type loc: object, optional + :param ip: Insertion point for MLIR operation + :type ip: object, optional + :return: The result of the logical AND operation + :rtype: Numeric + :raises ValueError: If no arguments are provided + + Supports multiple calling patterns: + + - and_(a): Returns a + - and_(a, b, c, ...): if a is truthy, returns and_(b, c, ...), otherwise returns a + + All arguments must be of the same type. + + Examples: + + .. code-block:: python + + # In Python, 'and' returns the second operand if the first is truthy, + # otherwise it returns the first operand + a = 5 + b = 3 + result = a and b # Returns 3 + + # In this DSL implementation, the behavior is similar but works with DSL types + a = Int32(5) + b = Int32(3) + result = and_(a, b) # Returns b + """ + if len(args) == 0: + raise ValueError("and_() requires at least one argument") + + if len(args) == 1: + return args[0] + + def and_op(lhs, rhs): + if not isinstance(lhs, (Numeric, cutlass_arith.ArithValue, int, float, bool)): + raise DSLNotImplemented(f"{type(lhs)} is not supported") + elif isinstance(lhs, (int, float, bool)) and isinstance( + rhs, (int, float, bool) + ): + return lhs and rhs + else: + return as_numeric(lhs).__dsl_and__(as_numeric(rhs)) + + return functools.reduce(and_op, args[1:], args[0]) + + +def or_(*args, loc=None, ip=None): + """Logical OR operation for DSL numeric types. 
+ + :param *args: One or more numeric values to OR together + :type *args: Numeric + :param loc: Source location for MLIR operation tracking + :type loc: object, optional + :param ip: Insertion point for MLIR operation + :type ip: object, optional + :return: The result of the logical OR operation + :rtype: Numeric + :raises ValueError: If no arguments are provided + + Supports multiple calling patterns: + + - or_(a): Returns a + - or_(a, b, c, ...): if a is truthy, returns a, otherwise returns or_(b, c, ...) + + Examples: + + .. code-block:: python + + # In Python, 'or' returns the first operand if it's truthy, + # otherwise it returns the second operand + a = 5 + b = 3 + result = a or b # Returns 5 + + # In this DSL implementation, the behavior is similar but works with DSL types + a = Int32(5) + b = Int32(3) + result = or_(a, b) # Returns a + """ + if len(args) == 0: + raise ValueError("or_() requires at least one argument") + + if len(args) == 1: + return args[0] + + def or_op(lhs, rhs): + if not isinstance(lhs, (Numeric, cutlass_arith.ArithValue, int, float, bool)): + raise DSLNotImplemented(f"{type(lhs)} is not supported") + elif isinstance(lhs, (int, float, bool)) and isinstance( + rhs, (int, float, bool) + ): + return lhs or rhs + else: + return as_numeric(lhs).__dsl_or__(as_numeric(rhs)) + + return functools.reduce(or_op, args[1:], args[0]) + + +def all_(iterable): + """Logical AND operation for all elements in an iterable. + + Returns True if all elements in the iterable are truthy, otherwise False. + This is the DSL equivalent of Python's built-in all() function. + + :param iterable: An iterable containing values to check + :type iterable: Iterable + :return: True if all elements are truthy, False otherwise + :rtype: Boolean + + Examples: + + .. 
code-block:: python + + # Check if all values are non-zero + values = [Int32(1), Int32(2), Int32(3)] + result = all_(values) # Returns True + + # Check if all conditions are met + conditions = [a > 0, b < 10, c != 0] + result = all_(conditions) # Returns True if all conditions are met + """ + bool_iterable = [Boolean(i) for i in iterable] + return functools.reduce( + lambda lhs, rhs: lhs.__dsl_and__(rhs) if hasattr(lhs, "__dsl_and__") else lhs, + bool_iterable, + Boolean(True), + ) + + +def any_(iterable): + """Logical OR operation for any element in an iterable. + + Returns True if any element in the iterable is truthy, otherwise False. + This is the DSL equivalent of Python's built-in any() function. + + :param iterable: An iterable containing values to check + :type iterable: Iterable + :return: True if any element is truthy, False otherwise + :rtype: Boolean + + Examples: + + .. code-block:: python + + # Check if any value is non-zero + values = [Int32(0), Int32(0), Int32(3)] + result = any_(values) # Returns True + + # Check if any condition is met + conditions = [a > 10, b < 0, c != 0] + result = any_(conditions) # Returns True if any condition is met + """ + bool_iterable = [Boolean(i) for i in iterable] + return functools.reduce( + lambda lhs, rhs: lhs.__dsl_or__(rhs) if hasattr(lhs, "__dsl_or__") else lhs, + bool_iterable, + Boolean(False), + ) + + +# ============================================================================= +# Conditional Expression +# ============================================================================= + + +def select_(cond, if_value, else_value): + def _as_scalar(value): + if isinstance(value, list): + if len(value) == 1: + return value[0] + else: + raise DSLRuntimeError( + "Conditional expression must have exactly one value in all expressions" + ) + return value + + if not is_dynamic_expression(cond): + raise DSLRuntimeError("Conditional expression must be dynamic") + + # Extract MLIR values + cond = 
extract_mlir_values(cond) + if is_dynamic_expression(if_value): + if_value = extract_mlir_values(if_value) + else: + if_value = const(if_value) + if is_dynamic_expression(else_value): + else_value = extract_mlir_values(else_value) + else: + else_value = const(else_value) + + return arith.SelectOp( + _as_scalar(cond), _as_scalar(if_value), _as_scalar(else_value) + ).result + + +# ============================================================================= +# Terminator +# ============================================================================= + + +def yield_out(args=[], loc=None, ip=None): + """ + Generate a yield operation. It it used to return values from a loop, if-else, or while region. + """ + scf.yield_(extract_mlir_values(args), loc=loc, ip=ip) + + +# ============================================================================= +# For Loop +# ============================================================================= + + +class LoopUnroll(ir.Attribute): + def __init__(self, **kwargs): + valid_keys = set(["count", "full"]) + def to_mlir_attr(val): + if isinstance(val, bool): + return "true" if val else "false" + elif isinstance(val, int): + return f"{val} : i32" + else: + raise DSLNotImplemented(f"{type(val)} is not supported") + + cfg = {key: to_mlir_attr(kwargs[key]) for key in valid_keys if key in kwargs} + if kwargs.get("count", None) == 1: + cfg["disable"] = "true" + + unroll = "<" + ", ".join(f"{key} = {value}" for key, value in cfg.items()) + ">" + + super().__init__( + ir.Attribute.parse(f"#llvm.loop_annotation") + ) + + +def for_generate( + start, + stop=None, + step=None, + iter_args: Optional[Sequence[ir.Value]] = None, + *, + unroll: LoopUnroll = None, + prefetch_stages=None, + loc=None, + ip=None, +): + """ + scf.for with yield support + """ + + if step is None: + step = 1 + if stop is None: + stop = start + start = 0 + start = const(start) + params = [start, stop, step] + for i, p in enumerate(params): + if isinstance(p, int): + p = 
const(p) + elif isinstance(p, float): + raise DSLRuntimeError(f"{p=} must be int.") + elif isinstance(p, Integer): + p = p.ir_value() + params[i] = p + + start, stop, step = params + + def _createI32Attr(value): + if not isinstance(value, int): + raise DSLRuntimeError(f"value must be int.") + return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), value) + + ir_iter_args = extract_mlir_values(iter_args) if iter_args is not None else None + if not _validate_iter_args_structure(iter_args, ir_iter_args): + raise DSLRuntimeError("iter_args: Elements should be extractable as ir.Value.") + for_op = scf.ForOp(start, stop, step, ir_iter_args, loc=loc, ip=ip) + if unroll is not None: + for_op.attributes["loop_annotation"] = unroll + + if prefetch_stages is not None: + for_op.attributes["cutlass.pipelining"] = _createI32Attr(prefetch_stages) + + iv = for_op.induction_variable + new_results = new_from_mlir_values(iter_args, for_op.results) + new_iter_args = new_from_mlir_values(iter_args, for_op.inner_iter_args) + new_iter_args = () if new_iter_args is None else tuple(new_iter_args) + + with ir.InsertionPoint(for_op.body): + if len(new_iter_args) > 1: + yield iv, new_iter_args, new_results + elif len(new_iter_args) == 1: + yield iv, new_iter_args[0], new_results[0] + else: + yield iv + + +# ============================================================================= +# Logical Operators +# ============================================================================= + + +def not_(lhs: Union[ir.Value, bool], *, loc=None, ip=None): + """ + Logical Not + """ + res = None + # Handle Python bool first to prevent infinite recursion + if type(lhs) == bool: + res = lhs ^ True + elif hasattr(lhs, "__dsl_not__"): + res = lhs.__dsl_not__(loc=loc, ip=ip) + elif is_dynamic_expression(lhs): + # If lhs is MLIR value, compute not using xor + res = arith.XOrIOp(lhs, const(1, lhs.type)).result + else: + res = bool(lhs) ^ True + + return res + + +# 
============================================================================= +# If/Else +# ============================================================================= + + +def if_generate( + cond: Boolean, + then_body: Callable, + else_body: Optional[Callable] = None, + input_args: List[DslType] = None, + return_types: List[DslType] = None, + *, + loc=None, + ip=None, +) -> List: + """ + Generate an IfOp with optional else branch and return values. + + Args: + cond: The condition expression + then_body: Function to execute in then branch + else_body: Optional function to execute in else branch + input_args: Arguments to pass to branch bodies + return_types: Expected return types for the operation + loc: Optional location information + ip: Optional insertion point + + Returns: + List of DSL typed results + """ + input_args = input_args or [] + mlir_return_types = [] + + # Validate and collect MLIR return types (if provided). + if return_types is not None: + for t in return_types: + if not isinstance(t, DslType): + raise DSLRuntimeError(f"{t=} must be a DslType.") + mlir_return_types.append(t.mlir_type) + + # Determine whether there's an else branch. + has_else = else_body is not None + + # Create the IfOp. + if_op = scf.IfOp( + Boolean(cond).ir_value(), mlir_return_types, hasElse=has_else, loc=loc, ip=ip + ) + + def _execute_and_yield_out(body, input_args): + yield_vals = body(*input_args) + if return_types is not None: + if not isinstance(yield_vals, Iterable): + # body only return single element + yield_vals = [yield_vals] + + yield_vals = [t(r) for t, r in zip(return_types, yield_vals)] + yield_out(yield_vals) + + # Generate the body for 'then'. + with ir.InsertionPoint(if_op.then_block): + _execute_and_yield_out(then_body, input_args) + + # Generate the body for 'else' if provided. + if has_else: + with ir.InsertionPoint(if_op.else_block): + _execute_and_yield_out(else_body, input_args) + + # Collect MLIR results. 
+ mlir_results = _get_op_result_or_op_results(if_op) + + if not isinstance(mlir_results, list): + mlir_results = [mlir_results] + + # Wrap the results with their DSL types. + if return_types is None: + return [] + + vals = [t(r) for t, r in zip(return_types, mlir_results)] + + if len(vals) == 1: + return vals[0] + + return vals + + +# ============================================================================= +# While Loop +# ============================================================================= + + +class WhileLoopContext: + """ + Context manager for a dynamic while loop. + """ + + def __init__( + self, + inputs: Sequence[Union[ir.Value, Numeric]], + condition: Callable[[Sequence[ir.Value]], ir.Value], + *, + loc=None, + ip=None, + ): + # Keep original inputs and allow recover original type information + self.inputs = inputs + + self.input_ir_values = extract_mlir_values(inputs) + + if not _validate_iter_args_structure(inputs, self.input_ir_values): + raise DSLRuntimeError("inputs: Elements should be extractable as ir.Value.") + + self.condition = condition + self.input_ir_types = [i.type for i in self.input_ir_values] + self.while_op = scf.WhileOp( + self.input_ir_types, self.input_ir_values, loc=loc, ip=ip + ) + + self.before_region = self.while_op.before + self.after_region = self.while_op.after + + self.before_region.blocks.append(*self.input_ir_types) + self.before_block = self.before_region.blocks[0] + + self.after_region.blocks.append(*self.input_ir_types) + self.after_block = self.after_region.blocks[0] + + def __enter__(self): + with ir.InsertionPoint(self.before_block): + args = new_from_mlir_values(self.inputs, self.before_block.arguments) + cond = self.condition(*args) + cond_ir_val = extract_mlir_values(cond) + scf.ConditionOp(cond_ir_val[0], [*self.before_block.arguments]) + self.ipoint_op = ir.InsertionPoint(self.after_block) + self.ipoint_op.__enter__() + return new_from_mlir_values(self.inputs, self.after_block.arguments) + + def 
__exit__(self, exc_type, exc_value, traceback): + self.ipoint_op.__exit__(exc_type, exc_value, traceback) + + @property + def results(self): + return new_from_mlir_values(self.inputs, self.while_op.results_) + + +def while_generate( + inputs: Sequence[Union[ir.Value, Numeric]], + condition: Callable[[Sequence[Union[ir.Value, Numeric]]], Union[ir.Value, Numeric]], + *, + loc=None, + ip=None, +) -> WhileLoopContext: + """ + Generate a WhileLoopContext for a dynamic loop. + """ + return WhileLoopContext(inputs, condition, loc=loc, ip=ip) + + +def equal(lhs, rhs): + if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): + return lhs == rhs + + # Both sequence + if isinstance(lhs, Sequence) and isinstance(rhs, Sequence): + # Short-circuit for unequal length + if len(lhs) != len(rhs): + return False + return all_(equal(l, r) for l, r in zip(lhs, rhs)) + return lhs == rhs + + +def not_equal(lhs, rhs): + if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): + return lhs != rhs + + # Both sequence + if isinstance(lhs, Sequence) and isinstance(rhs, Sequence): + # Short-circuit for unequal length + if len(lhs) != len(rhs): + return True + return any_(not_equal(l, r) for l, r in zip(lhs, rhs)) + + if hasattr(lhs, "__ne__"): + return lhs != rhs + elif hasattr(rhs, "__ne__"): + return rhs != lhs + else: + return not_(equal(lhs, rhs)) + + +def in_(lhs, rhs): + if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): + return lhs in rhs + + if not isinstance(rhs, Sequence): + raise DSLRuntimeError( + f"'in' not supported between instances of {type(lhs)} and {type(rhs)}" + ) + + return any_(equal(lhs, r) for r in rhs) + + +def _lte_gte(lhs, rhs, op): + def native_lte_gte(lhs, rhs, op): + match op: + case "<": + return lhs < rhs + case "<=": + if hasattr(lhs, "__le__"): + return lhs <= rhs + else: + return not_(lhs > rhs) + case ">": + return lhs > rhs + case ">=": + if hasattr(lhs, "__ge__"): + return lhs >= rhs + else: + return not_(lhs 
< rhs) + case _: + raise DSLRuntimeError(f"Unsupported comparison operator: {op}") + + if not is_dynamic_expression(lhs) and not is_dynamic_expression(rhs): + return native_lte_gte(lhs, rhs, op) + + # Both sequence, comparisons other than == and != do not allow mixing different types of sequences + if ( + isinstance(lhs, Sequence) + and isinstance(rhs, Sequence) + and type(lhs) == type(rhs) + ): + unequal_found = False + comp_results = [] + mask = [] + for l, r in zip(lhs, rhs): + is_equal = equal(l, r) + mask.append(not_(or_(is_equal, unequal_found))) + unequal_found = not_(is_equal) + comp_results.append(_lte_gte(l, r, op)) + + result = any_(and_(r, m) for r, m in zip(comp_results, mask)) + + if len(lhs) != len(rhs): + # Ref https://docs.python.org/3/tutorial/datastructures.html#comparing-sequences-and-other-types + # If one sequence is an initial sub-sequence of the other, the shorter sequence is the smaller (lesser) one + has_valid_mask = any_(mask) + match op: + case "<": + length_result = len(lhs) < len(rhs) + case ">": + length_result = len(lhs) > len(rhs) + case "<=": + length_result = len(lhs) <= len(rhs) + case ">=": + length_result = len(lhs) >= len(rhs) + if type(has_valid_mask) == bool: + return result if has_valid_mask else length_result + else: + return select_(has_valid_mask, result, length_result) + else: + if op in {"<=", ">="}: + # If no unequal, return True + return select_(unequal_found, result, True) + else: + return result + else: + return native_lte_gte(lhs, rhs, op) + + +def greater_than(lhs, rhs): + return _lte_gte(lhs, rhs, ">") + + +def greater_equal(lhs, rhs): + return _lte_gte(lhs, rhs, ">=") + + +def less_than(lhs, rhs): + return _lte_gte(lhs, rhs, "<") + + +def less_equal(lhs, rhs): + return _lte_gte(lhs, rhs, "<=") + + +def _compare_dispatch(lhs, rhs, op): + """ + Dispatches the comparison operation between lhs and rhs based on the given operator. + + :param lhs: The left-hand side operand for the comparison. 
+ :param rhs: The right-hand side operand for the comparison. + :param op: The comparison operator as a string. Supported operators are: + - "is", "is not": Python identity comparisons. + - "in", "not in": Membership tests. + - "==", "!=": Equality and inequality. + - "<", ">", "<=", ">=": Relational comparisons. + :return: The result of the comparison, which may be a boolean or a DSL-specific type. + :raises DSLRuntimeError: If the operator is not supported. + """ + match op: + # 'is' and 'is not' are pure python operators + case "is": + return lhs is rhs + case "is not": + return lhs is not rhs + case "in": + return in_(lhs, rhs) + case "not in": + return not_(in_(lhs, rhs)) + case "==": + return equal(lhs, rhs) + case "!=": + return not_equal(lhs, rhs) + case "<": + return less_than(lhs, rhs) + case ">": + return greater_than(lhs, rhs) + case ">=": + return greater_equal(lhs, rhs) + case "<=": + return less_equal(lhs, rhs) + case _: + raise DSLRuntimeError(f"Unsupported comparison operator: {op}") + + +def _compare_executor(left, comparators, ops): + # Fast path for single comparison + if len(comparators) == 1: + return _compare_dispatch(left, comparators[0], ops[0]) + + # Chain comparison, dispatch in a loop + result = True + current = left + for comparator, op in zip(comparators, ops): + cmp_result = _compare_dispatch(current, comparator, op) + result = and_(result, cmp_result) + current = comparator + + return result + + +def _builtin_redirector(fcn): + if fcn == builtins.max: + return max + elif fcn == builtins.min: + return min + elif fcn == builtins.any: + return any_ + elif fcn == builtins.all: + return all_ + else: + raise DSLRuntimeError(f"Unsupported built-in function: {fcn}") + + +# ============================================================================= +# Set the AST decorator +# ============================================================================= + +# Set the DSL specific functions +executor.set_functions( + 
is_dynamic_expression=is_dynamic_expression, + loop_execute_range_dynamic=_loop_execute_range_dynamic, + if_dynamic=_if_execute_dynamic, + while_dynamic=_while_execute_dynamic, + compare_executor=_compare_executor, + any_executor=any_, + all_executor=all_, + builtin_redirector=_builtin_redirector, +) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b4d8953d69b4100871a496623f051d60ab2a8d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/CuTeDSL/cutlass_dsl/cutlass_ast_decorators.py @@ -0,0 +1,633 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from typing import List, Tuple +from types import NoneType +from cutlass._mlir import ir +from cutlass._mlir.dialects import scf, arith +from cutlass._mlir.extras import types as T +from collections.abc import Sequence + +from ..base_dsl.dsl import is_dynamic_expression +from ..base_dsl.ast_helpers import * +from ..base_dsl.utils.logger import log +from ..base_dsl import typing as t +from ..base_dsl.typing import ( + Int32, + Float32, + Boolean, + Numeric, + get_mlir_types, + as_numeric, +) +from . 
class LoopUnroll(ir.Attribute):
    """
    ``#llvm.loop_annotation`` attribute carrying loop-unroll settings.

    Keyword arguments (any other key is ignored):
        count: ``int`` unroll factor. ``count=1`` additionally emits
            ``disable = true``.
        full: ``bool`` requesting full unrolling.

    Raises:
        DSLNotImplemented: If a supplied value is neither ``bool`` nor ``int``.
    """

    def __init__(self, **kwargs):
        # A tuple (rather than a set) keeps the generated field order stable.
        valid_keys = ("count", "full")

        def to_mlir_attr(val):
            # bool is tested first because bool is a subclass of int
            if isinstance(val, bool):
                return "true" if val else "false"
            elif isinstance(val, int):
                return f"{val} : i32"
            else:
                raise DSLNotImplemented(f"{type(val)} is not supported")

        cfg = {key: to_mlir_attr(kwargs[key]) for key in valid_keys if key in kwargs}
        if kwargs.get("count", None) == 1:
            cfg["disable"] = "true"

        unroll = "<" + ", ".join(f"{key} = {value}" for key, value in cfg.items()) + ">"

        # The unroll settings must be embedded in the annotation text; parsing a
        # bare `#llvm.loop_annotation` would silently drop them.
        super().__init__(
            ir.Attribute.parse(
                f"#llvm.loop_annotation<unroll = #llvm.loop_unroll{unroll}>"
            )
        )
+ + Args: + original_value: The value before entering the SCF operation region + region_value: The value returned from the SCF operation region + arg_name: Name of the argument being checked (for error reporting) + op_type_name: Type of SCF operation (e.g., 'for', 'if', 'while') for error reporting + + Raises: + DSLRuntimeError: If the region value has a different type than the original value. + The error includes suggestions for using compile-time control flow instead. + + Note: + This method performs relaxed type checking that allows inheritance relationships. + For example, a child class can be returned where a parent class was expected. + However, fundamental type changes (like None to non-None, different sequence types, + or different numeric types) are not allowed in dynamic SCF operations. + """ + + def get_type_name(value): + if isinstance(value, NoneType): + return "None" + elif isinstance(value, Sequence): + return f"{type(value).__name__}<{len(value)}>" + else: + return type(value).__name__ + + # Check for type mismatches + type_mismatch = False + old_type_name = None + new_type_name = None + + # Handle None type changes + if isinstance(original_value, NoneType) != isinstance(region_value, NoneType): + type_mismatch = True + old_type_name = get_type_name(original_value) + new_type_name = get_type_name(region_value) + # Handle sequence type/length changes + elif isinstance(original_value, Sequence) and isinstance( + region_value, Sequence + ): + if type(original_value) != type(region_value) or len(original_value) != len( + region_value + ): + type_mismatch = True + old_type_name = get_type_name(original_value) + new_type_name = get_type_name(region_value) + # Handle numeric type changes + elif isinstance( + original_value, (Numeric, ArithValue, ir.Value, int, float, bool) + ) or isinstance( + region_value, (Numeric, ArithValue, ir.Value, int, float, bool) + ): + try: + original_numeric = as_numeric(original_value) + region_numeric = 
as_numeric(region_value) + if original_numeric.dtype != region_numeric.dtype: + type_mismatch = True + old_type_name = original_numeric.dtype.__name__ + new_type_name = region_numeric.dtype.__name__ + except Exception: + pass + # Handle general type changes (relaxed for inheritance) + elif type(original_value) != type(region_value): + old_type = type(original_value) + new_type = type(region_value) + if not (issubclass(old_type, new_type) or issubclass(new_type, old_type)): + type_mismatch = True + old_type_name = old_type.__name__ + new_type_name = new_type.__name__ + + if type_mismatch: + raise DSLRuntimeError( + f"`{arg_name}` is {old_type_name} prior to this `{op_type_name}`, " + f"and update to {new_type_name} inside of this `{op_type_name}` is not supported.", + suggestion=( + f"Please avoid changing type inside a dynamic `{op_type_name}`, " + f"or change to compile-time control flow by marking this `{op_type_name}` with " + f"`{'range_constexpr' if op_type_name == 'for' else 'const_expr'}`." + ), + ) + + def scf_execute_dynamic( + self, + op_type_name: str, + mix_iter_args: List[Any], + full_write_args_count: int, + mix_iter_arg_names: List[str], + create_op_func: Callable[[List[ir.Value]], ir.Operation], + region_builders: List[ + Callable[ + [ + "ir.Operation", + List["ir.Value"], # block_args + List["ir.Value"], # dyn_yield_ops + PyTreeDef, + List[Any], + int, + ], + Any, + ] + ], + # block_term_op_builder[region_builder] = scf_op_builder + # e.g. 
scf.ConditionOp for while loop + block_term_op_builder: Dict[Callable, Callable] = {}, + ) -> Any: + # 1) Unpack + ir_values, pytree_def = cutlass_dsl.unpack_to_irvalue( + mix_iter_args, op_type_name, full_write_args_count + ) + # 2) Create the SCF op + op = create_op_func(ir_values) + log().debug("Generated scf.%s \n[%s]", op_type_name, op) + + # 3) Build the regions + for i, builder in enumerate(region_builders): + region = op.regions[i] + block = region.blocks[0] + with ir.InsertionPoint(block): + block_args = list(block.arguments) + region_result = builder( + op, + block_args, + ir_values, + pytree_def, + mix_iter_args, + full_write_args_count, + ) + + # Use custom terminator if provided for this builder, otherwise use default YieldOp + if builder in block_term_op_builder: + # Use the provided terminator generator + block_term_op_builder[builder](region_result, full_write_args_count) + else: + # Normalize region_result + region_result_list = ScfGenerator._normalize_region_result_to_list( + region_result + ) + # For standard yield op, check result + for arg, result, name in zip( + mix_iter_args, + region_result_list, + mix_iter_arg_names, + ): + ScfGenerator._check_region_result( + arg, result, name, op_type_name + ) + + # Default behavior - generate YieldOp + region_values, yield_pytree_def = cutlass_dsl.unpack_to_irvalue( + region_result_list, op_type_name, full_write_args_count + ) + + mismatch = check_tree_equal(pytree_def, yield_pytree_def) + if mismatch != -1: + # Get arg name + filterd_arg_names = ( + cutlass_dsl.filter_readonly_frozen_dataclass_names( + mix_iter_args, mix_iter_arg_names, full_write_args_count + ) + ) + + raise DSLRuntimeError( + f"`{filterd_arg_names[mismatch]}` is structured different after this `{op_type_name}`.", + suggestion=( + f"Please avoid changing type structure inside a dynamic `{op_type_name}`, " + f"or change to compile-time control flow by marking this `{op_type_name}` with " + f"`{'range_constexpr' if op_type_name == 'for' 
def _attr_const_check(attr, expected_type, attr_name):
    """
    Ensure a loop attribute is a plain Python value of exactly ``expected_type``.

    The type test is an identity test on ``type(attr)``, not ``isinstance``,
    so that a ``bool`` is rejected where an ``int`` is required.

    Raises:
        DSLRuntimeError: If ``attr`` is a dynamic expression or its type is
            not exactly ``expected_type``.
    """
    is_static = not is_dynamic_expression(attr)
    if is_static and type(attr) is expected_type:
        return
    raise DSLRuntimeError(
        f"loop attribute `{attr_name}` must be a Python value of type `{expected_type.__name__}`, got `{type(attr).__name__}`."
    )
+ """ + scf_gen = ScfGenerator() + + def create_for_op(dyn_yield_ops: List[ir.Value]): + for d in dyn_yield_ops: + if not isinstance(d, ir.Value): + raise DSLRuntimeError( + f"Invalid dyn_yield_ops: {dyn_yield_ops} \n\tExpected ir.Value, got {type(d)}" + ) + + # Convert Python ints or values to IR constants if needed + start_ = t.as_numeric(start) + stop_ = t.as_numeric(stop) + step_ = t.as_numeric(step) + assert start_ is not t.Int32, "Start is required for scf.for" + assert stop_ is not t.Int32, "Stop is required for scf.for" + assert step_ is not t.Int32, "Step is required for scf.for" + start_ = start_.ir_value() + stop_ = stop_.ir_value() + step_ = step_.ir_value() + + # Attributes must be pure Python value, add a check + _attr_const_check(unroll, int, "unroll") + _attr_const_check(unroll_full, bool, "unroll_full") + + # Possibly attach unroll attributes + unroll_attr = None + if unroll_full: + unroll_attr = LoopUnroll(full=True) + elif unroll != -1: + unroll_attr = LoopUnroll(count=unroll) + log().debug("Unroll attribute: %s", unroll_attr) + + prefetch_stages_attr = None + if prefetch_stages is not None: + _attr_const_check(prefetch_stages, int, "prefetch_stages") + if prefetch_stages >= 0: + prefetch_stages_attr = ir.IntegerAttr.get( + ir.IntegerType.get_signless(32), prefetch_stages + ) + else: + raise DSLRuntimeError( + f"loop attribute `prefetch_stages` must be non-negative, got `{prefetch_stages}`." 
+ ) + log().debug("prefetch_stages attribute: %s", prefetch_stages_attr) + + log().debug( + "Creating scf.ForOp \n\t\tstart=%s: type : %s\n\t\tstop=%s: type : %s\n\t\tstep=%s: type : %s", + start_, + type(start_), + stop_, + type(stop_), + step_, + type(step_), + ) + # Create scf.ForOp, passing iteration args if any + try: + if not dyn_yield_ops: + for_op = scf.ForOp(start_, stop_, step_) + else: + for_op = scf.ForOp(start_, stop_, step_, list(dyn_yield_ops)) + except Exception as e: + yield_ops = "\n".join( + f"\t\t{i} => {d} : type : {type(d)}" + for i, d in enumerate(dyn_yield_ops) + ) + raise DSLRuntimeError( + f"Failed to create scf.ForOp \n\t\tstart={start_}: type : {type(start_)}" + f"\n\t\tstop={stop_}: type : {type(stop_)}\n\t\tstep={step_}: type : {type(step_)}" + f", \n\tdyn_yield_ops:\n{yield_ops}" + ) from e + + if unroll_attr is not None: + for_op.attributes["loop_annotation"] = unroll_attr + + if prefetch_stages_attr is not None: + for_op.attributes["cutlass.pipelining"] = prefetch_stages_attr + + return for_op + + def for_body_builder( + op, + block_args, + _, + pytree_def, + mix_iter_args, + full_write_args_count, + ): + # scf.ForOp block_args are typically [induction_var, iter_args...] 
+ # But MLIR also gives you op.induction_variable + iv = t.as_numeric(op.induction_variable) + log().debug( + "For body builder: %s block_args: %s full_write_args_count: %s", + iv, + block_args, + full_write_args_count, + ) + # block_args[1:] are iteration variables + func_args = [] + func_args.extend( + cutlass_dsl.pack_from_irvalue( + block_args[1:], pytree_def, mix_iter_args, full_write_args_count + ) + ) + if not func_args: + # No iteration arguments, or only the induction var + func(iv) + return [] # yield nothing + else: + updated_func_args = func(iv, *func_args) + return updated_func_args + + # Now call the universal SCF executor with a single region builder + return scf_gen.scf_execute_dynamic( + op_type_name="for", + mix_iter_args=mix_iter_args, + full_write_args_count=full_write_args_count, + mix_iter_arg_names=mix_iter_arg_names, + create_op_func=create_for_op, + region_builders=[for_body_builder], + ) + + +def _if_execute_dynamic( + pred: "ir.Value", + then_block: Callable, + else_block: Callable = None, + mix_yield_args: List[Any] = [], + full_write_args_count: int = 0, + mix_yield_arg_names: List[str] = [], + if_constexpr=None, # ignoring for brevity +): + """ + Build an scf.if with optional else, using our universal helper. 
def _while_execute_dynamic(
    while_before_block: Callable,
    while_after_block: Callable = None,
    write_args=[],
    full_write_args_count=0,
    write_args_names=[],
):
    """
    Build an scf.while op for a dynamic loop and return its packed results.

    The work is delegated to ScfGenerator.scf_execute_dynamic with two region
    builders: the "before" (condition) region, which is terminated by an
    scf.ConditionOp, and the "after" (body) region, which is left to the
    executor's default terminator handling.

    Args:
        while_before_block: Function that returns (condition, updated_values)
        while_after_block: Function that returns updated values
        write_args: Values that are updated in the loop
        full_write_args_count: Forwarded unchanged to the pack/unpack helpers
            and to scf_execute_dynamic
        write_args_names: Forwarded to scf_execute_dynamic as
            `mix_iter_arg_names`

    Returns:
        Whatever ScfGenerator.scf_execute_dynamic returns for this op.

    See create_while_function in ast_preprocessor.py for details on the input structure.
    """
    log().debug("_while_execute_dynamic")
    while_op_type_name = "while"
    scf_gen = ScfGenerator()

    def create_while_op(dyn_yield_ops: List[ir.Value]):
        # Create the while operation with the types from yield_args
        result_types = [arg.type for arg in dyn_yield_ops]
        try:
            while_op = scf.WhileOp(result_types, dyn_yield_ops)
            # Both regions get one block whose arguments mirror the result types
            while_op.before.blocks.append(*result_types)
            while_op.after.blocks.append(*result_types)
            log().debug("[%s]", while_op)
            return while_op
        except Exception as e:
            # Re-raise as a DSL error that lists every operand and its Python type
            yield_ops = "\n".join(
                f"\t\t{i} => {d} : type : {type(d)}"
                for i, d in enumerate(dyn_yield_ops)
            )
            raise DSLRuntimeError(
                f"Failed to create scf.WhileOp with yield_ops:\n{yield_ops}"
            ) from e

    def before_block_builder(
        op,
        block_args,
        _,
        pytree_def,
        mix_iter_args,
        full_write_args_count,
    ):
        # Build the before (condition) block
        flat_args = []
        flat_args.extend(
            cutlass_dsl.pack_from_irvalue(
                block_args, pytree_def, mix_iter_args, full_write_args_count
            )
        )

        log().debug("before block args: %s", flat_args)

        cond, before_results = while_before_block(*flat_args)

        # A single (non-list) result is wrapped so it can be handled uniformly
        if not isinstance(before_results, (list, ir.OpResultList)):
            before_results = [before_results]

        log().debug("cond [%s]", cond)
        log().debug(
            "before_results [%s]",
            before_results,
        )

        return cond, before_results

    def before_block_terminator(cond_and_results, full_write_args_count):
        # Generate a condition op instead of yield op
        cond = cond_and_results[0]
        before_result_list = ScfGenerator._normalize_region_result_to_list(
            cond_and_results[1]
        )
        ir_cond = as_numeric(cond).ir_value()
        ir_results_list, pytree_def = cutlass_dsl.unpack_to_irvalue(
            before_result_list, while_op_type_name, full_write_args_count
        )
        log().debug(
            "creating scf.ConditionOp with [%s], [%s]",
            ir_cond,
            ir_results_list,
        )
        scf.ConditionOp(ir_cond, ir_results_list)

    def after_block_builder(
        op,
        block_args,
        _,
        pytree_def,
        mix_iter_args,
        full_write_args_count,
    ):
        # Build the after (body) block
        flat_args = []
        flat_args.extend(
            cutlass_dsl.pack_from_irvalue(
                block_args, pytree_def, mix_iter_args, full_write_args_count
            )
        )

        log().debug("after block args: %s", flat_args)

        after_results = while_after_block(*flat_args)

        # A single (non-list) result is wrapped so it can be handled uniformly
        if not isinstance(after_results, (list, ir.OpResultList)):
            after_results = [after_results]

        log().debug(
            "after_results [%s]",
            after_results,
        )

        return after_results

    # Call the universal SCF executor with two region builders
    return scf_gen.scf_execute_dynamic(
        op_type_name=while_op_type_name,
        mix_iter_args=write_args,
        full_write_args_count=full_write_args_count,
        mix_iter_arg_names=write_args_names,
        create_op_func=create_while_op,
        region_builders=[before_block_builder, after_block_builder],
        block_term_op_builder={
            before_block_builder: before_block_terminator
        },  # Only customize the before block
    )
def unzip2(pairs: Iterable[tuple[Any, Any]]) -> tuple[list[Any], list[Any]]:
    """
    Split an iterable of 2-item pairs into two parallel lists.

    The iterable is consumed once; an empty input yields two empty lists.
    """
    materialized = [(first, second) for first, second in pairs]
    firsts = [first for first, _ in materialized]
    seconds = [second for _, second in materialized]
    return firsts, seconds


def get_fully_qualified_class_name(x: Any) -> str:
    """
    Return ``'<module>.<qualname>'`` for the class of ``x``.

    Example:
        >>> get_fully_qualified_class_name([1, 2, 3])
        'builtins.list'
    """
    cls = x.__class__
    return f"{cls.__module__}.{cls.__qualname__}"
def is_dynamic_expression(x: Any) -> bool:
    """
    Report whether ``x`` exposes the DynamicExpression protocol.

    The protocol is purely structural: the object must provide both
    ``__extract_mlir_values__`` and ``__new_from_mlir_values__``.

    Args:
        x: Any object to check

    Returns:
        bool: True only when both attributes are present.
    """
    return hasattr(x, "__extract_mlir_values__") and hasattr(
        x, "__new_from_mlir_values__"
    )


def is_constexpr_field(field: dataclasses.Field) -> bool:
    """
    Report whether a dataclass field is annotated as ``Constexpr``.

    Both the bare ``Constexpr`` annotation and a subscripted form whose
    ``typing.get_origin`` is ``Constexpr`` count.
    """
    annotation = field.type
    return annotation is Constexpr or get_origin(annotation) is Constexpr
+ + +@dataclasses.dataclass(frozen=True) +class Leaf: + """ + Represents a leaf node in a pytree structure. + + Attributes: + is_numeric: Whether this leaf contains a `Numeric` value + is_none: Whether this leaf represents None + node_metadata: SimpleNamespace metadata associated with this leaf + ir_type_str: String representation of the IR type + """ + is_numeric: bool = False + is_none: bool = False + node_metadata: SimpleNamespace = None + ir_type_str: str = None + + +# ============================================================================= +# Default to_iterable and from_iterable +# ============================================================================= + + +def extract_dataclass_members(x: Any) -> tuple[list[str], list[Any]]: + """ + Extract non-method, non-function attributes from a dataclass instance. + + Args: + x: A dataclass instance + + Returns: + tuple: (field_names, field_values) lists + """ + fields = [field.name for field in dataclasses.fields(x)] + + # If the dataclass has extra fields, raise an error + for k in x.__dict__.keys(): + if k not in fields: + raise DSLTreeFlattenError( + f"`{x}` has extra field `{k}`", + type_str=get_fully_qualified_class_name(x), + ) + + if not fields: + return [], [] + + # record constexpr fields + members = [] + constexpr_fields = [] + for field in dataclasses.fields(x): + if is_constexpr_field(field): + constexpr_fields.append(field.name) + fields.remove(field.name) + v = getattr(x, field.name) + if is_dynamic_expression(v): + raise DSLTreeFlattenError( + f"`{x}` has dynamic expression field `{field.name}` with a Constexpr type annotation `{field.type}`", + type_str=get_fully_qualified_class_name(x), + ) + else: + members.append(getattr(x, field.name)) + + return fields, members, constexpr_fields + + +def default_dataclass_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]: + """ + Convert a dataclass instance to iterable form for tree flattening. 
def set_dataclass_attributes(
    instance: Any,
    fields: list[str],
    values: Iterable[Any],
    constexpr_fields: list[str],
) -> Any:
    """
    Produce a copy of a dataclass instance with selected fields replaced.

    Args:
        instance: The dataclass instance to copy from
        fields: Names of the fields to overwrite
        values: New values, paired positionally with ``fields``
        constexpr_fields: Names whose current value on ``instance`` is
            passed through explicitly to the copy

    Returns:
        ``instance`` itself when ``fields`` is empty, otherwise the result of
        ``dataclasses.replace``.
    """
    if not fields:
        return instance

    overrides = dict(zip(fields, values))
    overrides.update((name, getattr(instance, name)) for name in constexpr_fields)
    return dataclasses.replace(instance, **overrides)


def default_dataclass_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Rebuild a dataclass instance from flattened children.

    Args:
        metadata: Namespace providing ``original_obj``, ``fields`` and
            ``constexpr_fields``
        children: New values for ``metadata.fields``, in the same order

    Returns:
        The rebuilt instance, which is also stored back into
        ``metadata.original_obj``.
    """
    rebuilt = set_dataclass_attributes(
        metadata.original_obj,
        metadata.fields,
        children,
        metadata.constexpr_fields,
    )
    metadata.original_obj = rebuilt
    return rebuilt
def dynamic_expression_from_iterable(
    metadata: SimpleNamespace, children: Iterable[Any]
) -> Any:
    """
    Rebuild a dynamic expression from its MLIR values.

    Args:
        metadata: Namespace whose ``original_obj`` provides
            ``__new_from_mlir_values__``
        children: Iterable of MLIR values; it is materialized into a list
            before being handed over

    Returns:
        Whatever ``original_obj.__new_from_mlir_values__`` returns.
    """
    rebuild = metadata.original_obj.__new_from_mlir_values__
    mlir_values = list(children)
    return rebuild(mlir_values)


def default_dict_to_iterable(x: Any) -> tuple[SimpleNamespace, list[Any]]:
    """
    Convert a dict or SimpleNamespace to iterable form.

    Returns:
        tuple: (metadata, values) where metadata records the fully qualified
        type name, the original object and the keys, and values are the
        corresponding values in the same order.
    """
    # A SimpleNamespace keeps its entries in __dict__; a dict is used directly
    mapping = x.__dict__ if isinstance(x, SimpleNamespace) else x
    keys = list(mapping.keys())
    values = list(mapping.values())

    metadata = SimpleNamespace(
        type_str=get_fully_qualified_class_name(x), original_obj=x, fields=keys
    )
    return metadata, values
def register_pytree_node(ty: type, to_iter: Callable, from_iter: Callable) -> NodeType:
    """
    Register (or overwrite) the pytree handlers for a type.

    Args:
        ty: The type to register
        to_iter: Converts an instance into (metadata, children)
        from_iter: Rebuilds an instance from (metadata, children)

    Returns:
        NodeType: The entry stored in the registry under ``ty``
    """
    node_type = NodeType(name=str(ty), to_iterable=to_iter, from_iterable=from_iter)
    _node_types[ty] = node_type
    return node_type


def register_default_node_types() -> None:
    """Register handlers for tuple, list, dict and SimpleNamespace."""

    def sequence_to_iterable(seq):
        # Metadata records the length; children are the elements in order
        return SimpleNamespace(length=len(seq)), list(seq)

    register_pytree_node(tuple, sequence_to_iterable, lambda _, xs: tuple(xs))
    register_pytree_node(list, sequence_to_iterable, lambda _, xs: list(xs))
    for mapping_type in (dict, SimpleNamespace):
        register_pytree_node(
            mapping_type, default_dict_to_iterable, default_dict_from_iterable
        )
def tree_flatten(x: Any) -> tuple[list[Any], PyTreeDef]:
    """
    Flatten a nested structure into leaf values plus a tree definition.

    Args:
        x: The nested structure to flatten

    Returns:
        tuple: (flat_values, treedef) where flat_values is a list holding the
        leaves produced by ``_tree_flatten`` and treedef is the structure it
        reported.

    Raises:
        DSLTreeFlattenError: If the structure contains unsupported types

    Example:
        >>> tree_flatten([1, [2, 3], 4])
        ([1, 2, 3, 4], PyTreeDef(...))
    """
    leaves, treedef = _tree_flatten(x)
    # The internal helper may hand back a lazy iterable; materialize it here
    return [*leaves], treedef
def create_leaf_for_value(
    x: Any,
    is_numeric: bool = False,
    is_none: bool = False,
    node_metadata: SimpleNamespace = None,
    ir_type_str: str = None,
) -> Leaf:
    """
    Build the Leaf descriptor for a value.

    Args:
        x: The value the leaf describes
        is_numeric: Stored on the leaf as-is
        is_none: Stored on the leaf as-is
        node_metadata: Stored on the leaf as-is
        ir_type_str: IR type string; when falsy it falls back to
            ``str(x.type)`` if ``x`` has a ``type`` attribute, else ``None``

    Returns:
        Leaf: The created leaf node
    """
    resolved_type_str = ir_type_str
    if not resolved_type_str:
        resolved_type_str = str(x.type) if hasattr(x, "type") else None

    return Leaf(
        is_numeric=is_numeric,
        is_none=is_none,
        node_metadata=node_metadata,
        ir_type_str=resolved_type_str,
    )
+ + Args: + x: The object to flatten + + Returns: + tuple: (flattened_values, treedef) where flattened_values is an iterable + of leaf values and treedef is the tree structure + + Raises: + DSLTreeFlattenError: If the object type is not supported + """ + match x: + case None: + return [], create_leaf_for_value(x, is_none=True) + + case ArithValue() if is_dynamic_expression(x): + v = x.__extract_mlir_values__() + return v, create_leaf_for_value( + x, + node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x), + ir_type_str=str(v[0].type), + ) + + case ArithValue(): + return [x], create_leaf_for_value(x, is_numeric=True) + + case ir.Value(): + return [x], create_leaf_for_value(x) + + case Numeric(): + v = x.__extract_mlir_values__() + return v, create_leaf_for_value( + x, + node_metadata=SimpleNamespace(is_dynamic_expression=1, original_obj=x), + ir_type_str=str(v[0].type), + ) + + case _: + node_type = get_registered_node_types_or_insert(x) + if node_type: + node_metadata, children = node_type.to_iterable(x) + children_flat, child_trees = unzip2(map(_tree_flatten, children)) + flattened = it.chain.from_iterable(children_flat) + return flattened, PyTreeDef( + node_type, node_metadata, tuple(child_trees) + ) + + # Try to convert to numeric + try: + nval = as_numeric(x).ir_value() + return [nval], create_leaf_for_value(nval, is_numeric=True) + except Exception: + raise DSLTreeFlattenError( + "Flatten Error", get_fully_qualified_class_name(x) + ) + + +def tree_unflatten(treedef: PyTreeDef, xs: list[Any]) -> Any: + """ + Reconstruct a nested structure from a flat list of values and tree definition. + + This is the inverse operation of tree_flatten. It takes the flattened + values and the tree structure definition to reconstruct the original + nested structure. 
+ + Args: + treedef: The tree structure definition from tree_flatten + xs: List of flat values to reconstruct from + + Returns: + The reconstructed nested structure + + Example: + >>> flat_values, treedef = tree_flatten([1, [2, 3], 4]) + >>> tree_unflatten(treedef, flat_values) + [1, [2, 3], 4] + """ + return _tree_unflatten(treedef, iter(xs)) + + +def _tree_unflatten(treedef: PyTreeDef | Leaf, xs: Iterator[Any]) -> Any: + """ + Internal function to reconstruct a tree structure. + + This is the core implementation of tree unflattening that handles + different types of tree definitions including Leaf nodes and PyTreeDef nodes. + + Args: + treedef: The tree structure definition + xs: Iterator of flat values to reconstruct from + + Returns: + The reconstructed object + """ + match treedef: + case Leaf(is_none=True): + return None + + case Leaf( + node_metadata=metadata + ) if metadata and metadata.is_dynamic_expression: + return metadata.original_obj.__new_from_mlir_values__([next(xs)]) + + case Leaf(is_numeric=True): + return as_numeric(next(xs)) + + case Leaf(): + return next(xs) + + case PyTreeDef(): + children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs) + return treedef.node_type.from_iterable(treedef.node_metadata, children) + + +def _check_tree_equal(lhs: Union[PyTreeDef, Leaf], rhs: Union[PyTreeDef, Leaf]) -> bool: + """ + Check if two tree definitions are structurally equal. + + This is a helper function for check_tree_equal that recursively compares + tree structures. 
def check_tree_equal(lhs: PyTreeDef, rhs: PyTreeDef) -> int:
    """
    Return the index of the first top-level child whose structure differs.

    Only the direct children of ``lhs`` and ``rhs`` are paired up; each pair is
    compared recursively with ``_check_tree_equal``. The comparison is
    structural: leaves are equal when their ``is_none`` flag and IR type string
    match, so two trees holding different *values* of the same type are
    considered equal. The node type and metadata of the two roots themselves
    are not compared here.

    Args:
        lhs: Left tree definition
        rhs: Right tree definition

    Returns:
        int: Index of the first differing child, or -1 if all children are equal

    Raises:
        AssertionError: If the two roots have a different number of children

    Example:
        >>> treedef1 = tree_flatten([1, [2, 3]])[1]
        >>> treedef2 = tree_flatten([1, [2, 3, 4]])[1]
        >>> check_tree_equal(treedef1, treedef2)
        1  # The second child has a different shape
    """
    assert len(lhs.child_treedefs) == len(rhs.child_treedefs)

    child_pairs = zip(lhs.child_treedefs, rhs.child_treedefs)
    return next(
        (
            index
            for index, (left, right) in enumerate(child_pairs)
            if not _check_tree_equal(left, right)
        ),
        -1,
    )
def _cuda_install_path_from_nvcc() -> str:
    """
    Derive the CUDA installation root from the location of ``nvcc`` on PATH.

    The executable is located with ``shutil.which`` rather than by spawning
    ``/usr/bin/which``, so the lookup also works on systems that do not ship
    that binary at that path. The install root is everything that precedes
    ``/bin/nvcc`` in the resolved path.

    :returns: path of the CUDA installation that contains ``nvcc``
    :rtype: str

    :raises Exception: if ``nvcc`` cannot be found on PATH, or the derived
        installation path is not an existing directory
    """
    import shutil

    # Attempt to detect CUDA_INSTALL_PATH based on location of NVCC
    nvcc_path = shutil.which('nvcc')
    if nvcc_path is None:
        raise Exception('Unable to find nvcc via `which` utility.')

    install_root = nvcc_path.split('/bin/nvcc')[0]
    if not os.path.isdir(install_root):
        raise Exception(f'Environment variable "CUDA_INSTALL_PATH" is not defined, '
                        f'and default path of {install_root} does not exist.')

    return install_root
_CUDA_INSTALL_PATH = None
def cuda_install_path():
    """
    Helper method for on-demand fetching of the CUDA installation path. This allows
    the import of CUTLASS to proceed even if NVCC is not available, preferring to
    raise this error only when an operation that needs NVCC is being performed.

    The ``CUDA_INSTALL_PATH`` environment variable takes precedence. NVCC is
    only searched for when the variable is not set, so an explicitly configured
    path works on machines where ``nvcc`` is not on PATH. The result is cached
    in the module-level ``_CUDA_INSTALL_PATH`` after the first call.

    :returns: the CUDA installation path
    :raises Exception: if the variable is unset and NVCC cannot be located
    """
    global _CUDA_INSTALL_PATH
    if _CUDA_INSTALL_PATH is None:
        # Do not pass the NVCC lookup as the getenv default: a default
        # argument is evaluated before the call, which would run (and possibly
        # fail) the lookup even when the environment variable is set.
        configured = os.getenv("CUDA_INSTALL_PATH")
        if configured is None:
            configured = _cuda_install_path_from_nvcc()
        _CUDA_INSTALL_PATH = configured
    return _CUDA_INSTALL_PATH
def check_cuda_versions():
    """
    Verify that the CUDA Python package is at least as new as NVCC.

    The ``cuda`` package version (with anything from ``rc`` onwards stripped)
    is compared with the version returned by ``nvcc_version()``, which is also
    stored on the module as ``this._nvcc_version``.

    Components are compared from most to least significant and the comparison
    stops at the first component that differs. Comparing every component
    independently would wrongly reject, for example, CUDA 12.0 against
    NVCC 11.8 because 0 < 8.

    :raises Exception: if the Python CUDA version is older than the NVCC
        version, or if the NVCC version has more than one component beyond
        those of the Python CUDA version
    """
    # Strip any additional information from the CUDA version
    _cuda_version = base_cuda.__version__.split("rc")[0]
    # Check that Python CUDA version exceeds NVCC version
    this._nvcc_version = nvcc_version()
    _cuda_list = _cuda_version.split('.')
    _nvcc_list = this._nvcc_version.split('.')
    too_old = f"Python CUDA version of {_cuda_version} must be greater than or equal to NVCC version of {this._nvcc_version}"

    for val_cuda, val_nvcc in zip(_cuda_list, _nvcc_list):
        cuda_part, nvcc_part = int(val_cuda), int(val_nvcc)
        if cuda_part < nvcc_part:
            raise Exception(too_old)
        if cuda_part > nvcc_part:
            # A more significant component is already newer; less significant
            # components can no longer make the Python package too old.
            break

    if len(_nvcc_list) > len(_cuda_list):
        if len(_nvcc_list) != len(_cuda_list) + 1:
            raise Exception(f"Malformatted NVCC version of {this._nvcc_version}")
        # Same leading components but NVCC carries an extra non-zero one,
        # e.g. 12.4 versus 12.4.1: NVCC is newer.
        if _nvcc_list[:-1] == _cuda_list and int(_nvcc_list[-1]) != 0:
            raise Exception(too_old)
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +from cutlass_cppgen.backend.arguments import * +from cutlass_cppgen.backend.c_types import * +from cutlass_cppgen.backend.compiler import ArtifactManager +from cutlass_cppgen.backend.conv2d_operation import * +from cutlass_cppgen.backend.epilogue import * +from cutlass_cppgen.backend.frontend import * +from cutlass_cppgen.backend.gemm_operation import * +from cutlass_cppgen.backend.library import * +from cutlass_cppgen.backend.memory_manager import PoolMemoryManager, create_memory_pool +from cutlass_cppgen.backend.operation import * +from cutlass_cppgen.backend.reduction_operation import * +from cutlass_cppgen.backend.type_hint import * +from cutlass_cppgen.backend.utils import * +from cutlass_cppgen.backend.utils.device import device_cc + +compiler = ArtifactManager() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/arguments.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..b1b0656a89a8b0a42b864429810b74bc433582d4 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/arguments.py @@ -0,0 +1,136 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
class ArgumentBase:
    """
    Base class for operation arguments.

    Converts the A/B/C/D operands into device pointers and remembers the
    device buffers and host tensors needed to copy results back and to release
    memory once the operation has finished.
    """

    def __init__(
        self,
        A: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        B: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        C: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        D: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        **kwargs,
    ) -> None:
        """
        :param A: operand A (device pointer or numpy/torch/cupy tensor)
        :param B: operand B
        :param C: operand C; may be ``None``
        :param D: output operand D
        :param kwargs: optional ``bias`` (default ``False``) and ``stream``
            (default ``cuda.CUstream(0)``)
        """
        # tensor_C can be interpreted as the bias with bias=True in keyword args
        self.bias = kwargs.get("bias", False)

        self.stream = kwargs.get("stream", cuda.CUstream(0))

        # RMM buffers used to track tensor lifetime
        self.buffers = {}
        # Host tensor to copy the computed result back
        self.host_tensors = {}

        self.ptr_A = self.tensor_to_ptr(A, "A")
        self.ptr_B = self.tensor_to_ptr(B, "B")
        self.ptr_C = self.tensor_to_ptr(C, "C")
        self.ptr_D = self.tensor_to_ptr(D, "D", is_output=True)
        # A raw device pointer carries no shape, so the element count of C is
        # only recorded for tensor inputs.
        if C is not None:
            if not isinstance(C, cuda.CUdeviceptr):
                self.tensor_c_numel = prod(C.shape)

    def tensor_to_ptr(self, tensor, name, is_output=False):
        """
        Convert and remember the input tensor to cuda.CUdeviceptr used by cuda python
        For numpy.ndarray, it also remembers the host buffer for synchronization

        :param tensor: ``None``, a numpy/torch/cupy tensor, or a ``cuda.CUdeviceptr``
        :param name: key under which a numpy tensor's device buffer is stored
        :param is_output: whether the tensor receives the operation's result

        :returns: device pointer for the tensor (a null pointer for ``None``)
        :raises TypeError: if the tensor type is not supported
        """
        if tensor is None:
            return cuda.CUdeviceptr(0)
        if is_numpy_tensor(tensor):
            if is_output:
                assert name
            self.buffers[name] = NumpyFrontend.argument(tensor, is_output)
            if is_output:
                self.host_tensors[name] = tensor
            return self.buffers[name].ptr
        elif is_torch_tensor(tensor):
            return TorchFrontend.argument(tensor)
        elif isinstance(tensor, cuda.CUdeviceptr):
            return tensor
        elif is_cupy_tensor(tensor):
            return CupyFrontend.argument(tensor)
        else:
            raise TypeError(
                "Unsupported Frontend. Only support numpy, torch, cupy and cuda.CUdeviceptr"
            )

    def sync(self, stream_sync=True):
        """
        Copy results back to the registered host tensors and free device memory.

        :param stream_sync: when true, wait for the whole device to finish
            (``cudaDeviceSynchronize``) before copying
        :raises RuntimeError: if synchronization or a device-to-host copy fails
        """
        if stream_sync:
            (err,) = cudart.cudaDeviceSynchronize()
            # Runtime API call: compare against the runtime error enum, as
            # free() does, rather than the driver's CUresult.
            if err != cudart.cudaError_t.cudaSuccess:
                raise RuntimeError("CUDA Error %s" % str(err))

        for key, host_tensor in self.host_tensors.items():
            (err,) = cuda.cuMemcpyDtoH(
                host_tensor,
                self.buffers[key].ptr,
                host_tensor.size * host_tensor.itemsize,
            )
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError("CUDA Error %s" % str(err))

        self.free()

    def free(self):
        """
        Frees allocated device-side memory

        :raises RuntimeError: if ``cudaFree`` reports an error
        """
        # Free any device memory allocated manually
        if not cutlass_cppgen.use_rmm:
            for buf in self.buffers.values():
                if isinstance(buf, DevicePtrWrapper):
                    err, = cudart.cudaFree(buf.ptr)
                    if err != cudart.cudaError_t.cudaSuccess:
                        raise RuntimeError(f"cudaFree failed with error {err}")

            if hasattr(self, "workspace_buffer") and isinstance(self.workspace_buffer, DevicePtrWrapper):
                err, = cudart.cudaFree(self.workspace_buffer.ptr)
                if err != cudart.cudaError_t.cudaSuccess:
                    raise RuntimeError(f"cudaFree failed with error {err}")
                # Drop the attribute so the workspace is not freed twice.
                del self.workspace_buffer
+################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
class GemmCoordBatched_(ctypes.Structure):
    """
    A GEMM coordinate (m, n, k) extended with a batch count. Used to encode
    batched GEMM inputs for CUTLASS 3 GEMMs.
    """

    # Four C ints, in this order: m, n, k, batch_count.
    _fields_ = [
        (field, ctypes.c_int) for field in ("m", "n", "k", "batch_count")
    ]

    def __init__(self, gemm_coord, batch_count) -> None:
        """
        :param gemm_coord: object exposing ``m``, ``n`` and ``k`` attributes
        :param batch_count: number of batches
        """
        self.m, self.n, self.k = gemm_coord.m, gemm_coord.n, gemm_coord.k
        self.batch_count = batch_count
def get_tile_scheduler_arguments_3x(
        tile_scheduler: TileSchedulerType,
        splits: int = 1):
    """
    Build the ctypes scheduler-argument structure for a 3.x tile scheduler.

    :param tile_scheduler: tile scheduler the arguments are built for
    :type tile_scheduler: cutlass_library.TileSchedulerType
    :param splits: number of splits; only used by the StreamK scheduler
    :type splits: int

    :returns: ``_PersistentTileSchedulerArguments`` for the Default and
        Persistent schedulers, ``_PersistentTileSchedulerStreamKArguments``
        for StreamK, and ``None`` for any other scheduler type
    """
    swizzle = 1
    raster_order = 0  # Heuristic

    if tile_scheduler in (TileSchedulerType.Default, TileSchedulerType.Persistent):
        return _PersistentTileSchedulerArguments(swizzle, raster_order)

    if tile_scheduler == TileSchedulerType.StreamK:
        return _PersistentTileSchedulerStreamKArguments(
            splits,
            swizzle,
            raster_order,
            0,  # reduction_mode: Deterministic
            0,  # decomposition_mode: Heuristic
        )

    # No argument structure is defined for other scheduler types.
    return None
+ + :param kernel_schedule: type of kernel schedule to be used in the mainloop + :type kernel_schedule: cutlass_library.KernelScheduleType + :param element_A: data type of operand A + :param element_B: data type of operand B + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + + :returns: ctypes structure to be used for the 3.x kernel's mainloop parameters + :rtype: ctypes.Structure + """ + class _MainloopArgumentsTma(ctypes.Structure): + _fields_ = [ + ("ptr_A", ctypes.c_void_p), + ("stride_A", StrideBatched_), + ("ptr_B", ctypes.c_void_p), + ("stride_B", StrideBatched_), + ("mma_promotion_interval", ctypes.c_int) + ] + + @staticmethod + def from_generic_mainloop_args(args: GenericMainloopArguments3x_): + return _MainloopArgumentsTma( + args.ptr_A, args.stride_A, args.ptr_B, args.stride_B, + args.mma_promotion_interval + ) + + class _MainloopArgumentsMultistage(ctypes.Structure): + _fields_ = [ + ("ptr_A", ctypes.c_void_p), + ("stride_A", StrideBatched_), + ("ptr_B", ctypes.c_void_p), + ("stride_B", StrideBatched_), + ] + + @staticmethod + def from_generic_mainloop_args(args: GenericMainloopArguments3x_): + return _MainloopArgumentsMultistage( + args.ptr_A, args.stride_A, args.ptr_B, args.stride_B, + ) + + # Currently all 3.x kernels (CpAsync and Tma) have the same argument structure. + # Should that become not the case, this is the place to return custom ctypes + # structures based on selected kernel schedule. 
def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor, scheduler_args, default_epilogue):
    """
    Create the ctypes structures describing the arguments of a 3.x GEMM.

    :param mainloop_arguments: ctypes type used for the ``mainloop`` field
    :param epilogue_functor: provides ``epilogue_type`` (and optionally
        ``epilogue_type_evt``, ``visitor``, ``arg_c_type``, ``arg_d_type``)
    :param scheduler_args: scheduler-argument instance; its type becomes the
        type of the ``scheduler`` field
    :param default_epilogue: when true, ``epilogue_type`` is always used even
        if the functor also offers ``epilogue_type_evt``

    :returns: tuple ``(_GemmArguments, _EpilogueArguments,
        _EpilogueOutputOpParams, _HardwareInfo)``
    """
    use_evt_type = (not default_epilogue) and hasattr(epilogue_functor, "epilogue_type_evt")
    _EpilogueOutputOpParams = (
        epilogue_functor.epilogue_type_evt
        if use_evt_type
        else epilogue_functor.epilogue_type
    )

    if not hasattr(epilogue_functor, "visitor"):
        # Plain epilogue: raw C/D pointers with their strides.
        class _EpilogueArguments(ctypes.Structure):
            _fields_ = [
                ("epilogue", _EpilogueOutputOpParams),
                ("ptr_C", ctypes.c_void_p),
                ("stride_C", StrideBatched_),
                ("ptr_D", ctypes.c_void_p),
                ("stride_D", StrideBatched_),
            ]
    else:
        # Visitor epilogue: C and D are wrapped in the functor's own argument
        # types. The stride parameters are accepted but not stored.
        class _EpilogueArguments(ctypes.Structure):
            _fields_ = [
                ("epilogue", _EpilogueOutputOpParams),
                ("arg_C", epilogue_functor.arg_c_type),
                ("arg_D", epilogue_functor.arg_d_type)
            ]

            def __init__(self, output_op, ptr_c, stride_c, ptr_d, stride_d) -> None:
                self.epilogue = output_op
                self.arg_C = epilogue_functor.arg_c_type(ptr_c)
                self.arg_D = epilogue_functor.arg_d_type(ptr_d)

    class _HardwareInfo(ctypes.Structure):
        _fields_ = [
            ("device_id", ctypes.c_int),
            ("sm_count", ctypes.c_int),
            ("max_active_clusters", ctypes.c_int),
            ("cluster_shape", dim3_),
            ("cluster_shape_fallback", dim3_),
        ]

    scheduler_type = type(scheduler_args)

    class _GemmArguments(ctypes.Structure):
        _fields_ = [
            ("mode", ctypes.c_int),
            ("problem_size", GemmCoordBatched_),
            ("mainloop", mainloop_arguments),
            ("epilogue", _EpilogueArguments),
            ("hw_info", _HardwareInfo),
            ("scheduler", scheduler_type),
        ]

    return _GemmArguments, _EpilogueArguments, _EpilogueOutputOpParams, _HardwareInfo
def get_gemm_arguments_streamk(epilogue_functor):
    """
    Create the ctypes argument structure for the stream-K GEMM variant
    (variant taken from the function name).

    Field order is kept exactly as declared; it presumably mirrors the layout
    of the corresponding C++ argument struct -- confirm before reordering.

    :param epilogue_functor: provides ``epilogue_type``, the ctypes type of the
        ``epilogue`` field

    :returns: tuple ``(_GemmArguments, _EpilogueOutputOpParams)``
    """
    _EpilogueOutputOpParams = epilogue_functor.epilogue_type

    leading = [
        ("mode", ctypes.c_int),
        ("problem_size", GemmCoord_),
        ("batch_count", ctypes.c_int),
        ("epilogue", _EpilogueOutputOpParams),
    ]
    # ptr_A .. ptr_D
    pointers = [(f"ptr_{operand}", ctypes.c_void_p) for operand in "ABCD"]
    # batch_stride_A..D, stride_a..d, lda..ldd -- all 64-bit integers
    int64_names = (
        [f"batch_stride_{operand}" for operand in "ABCD"]
        + [f"stride_{operand}" for operand in "abcd"]
        + [f"ld{operand}" for operand in "abcd"]
    )
    strides = [(field, ctypes.c_longlong) for field in int64_names]

    class _GemmArguments(ctypes.Structure):
        _fields_ = leading + pointers + strides + [("avail_sms", ctypes.c_int)]

    return _GemmArguments, _EpilogueOutputOpParams
_GEMMGroupedArguments(ctypes.Structure): + _fields_ = [ + ("problem_sizes", ctypes.c_void_p), + ("problem_count", ctypes.c_int), + ("threadblock_count", ctypes.c_int), + ("output_op", _EpilogueOutputOpParams), + ("ptr_A", ctypes.c_void_p), + ("ptr_B", ctypes.c_void_p), + ("ptr_C", ctypes.c_void_p), + ("ptr_D", ctypes.c_void_p), + ("lda", ctypes.c_void_p), + ("ldb", ctypes.c_void_p), + ("ldc", ctypes.c_void_p), + ("ldd", ctypes.c_void_p), + ("host_problem_sizes", ctypes.c_void_p) + ] + + return _GEMMGroupedArguments, _EpilogueOutputOpParams + + +############################################################################################ +# Convolution2D +############################################################################################ + + +class Conv2DProblemSize_(ctypes.Structure): + _fields_ = [ + ("N", ctypes.c_int), + ("H", ctypes.c_int), + ("W", ctypes.c_int), + ("C", ctypes.c_int), + ("P", ctypes.c_int), + ("Q", ctypes.c_int), + ("K", ctypes.c_int), + ("R", ctypes.c_int), + ("S", ctypes.c_int), + ("pad_h", ctypes.c_int), + ("pad_w", ctypes.c_int), + ("stride_h", ctypes.c_int), + ("stride_w", ctypes.c_int), + ("dilation_h", ctypes.c_int), + ("dilation_w", ctypes.c_int), + ("mode", ctypes.c_int), # kCrossCorrelation: 0, kConvolution: 1 + ("split_k_slices", ctypes.c_int), + ("groups", ctypes.c_int) + ] + + def __init__(self, problem_size) -> None: + for field_name, _ in self._fields_: + setattr(self, field_name, getattr(problem_size, field_name)) + + +class Layout4D(ctypes.Structure): + _fields_ = [("stride", ctypes.c_int * 3)] + + def __init__(self, tensor_ref): + stride = tensor_ref.stride() + setattr(self, "stride", (stride.at(0), stride.at(1), stride.at(2))) + + +class TensorRef_(ctypes.Structure): + _fields_ = [ + ("ptr", ctypes.c_void_p), + ("layout", Layout4D) + ] + + def __init__(self, tensor_ref): + setattr(self, "ptr", tensor_ref.data()) + setattr(self, "layout", Layout4D(tensor_ref.layout())) + + +class TensorRef2D_(ctypes.Structure): + 
_fields_ = [ + ("ptr", ctypes.c_void_p), + ("stride", ctypes.c_int) + ] + + +def get_conv2d_arguments(epilogue_functor): + _EpilogueOutputOpParams = epilogue_functor.epilogue_type + + class _Conv2dArguments(ctypes.Structure): + _fields_ = [ + ("conv_kind", ctypes.c_int), + ("problem_size", Conv2DProblemSize_), + ("ptr_A", ctypes.c_void_p), + ("ptr_B", ctypes.c_void_p), + ("ptr_C", ctypes.c_void_p), + ("ptr_D", ctypes.c_void_p), + ("tensor_C_numel", ctypes.c_int), + ("output_op", _EpilogueOutputOpParams), + ("split_k_mode", ctypes.c_int) + ] + + return _Conv2dArguments, _EpilogueOutputOpParams + + +############################################################################################ +# Reduction +############################################################################################ + + +def get_reduction_params(epilogue_functor): + _EpilogueOutputParams = epilogue_functor.epilogue_type + + class _ReductionParams(ctypes.Structure): + _fields_ = [ + ("problem_size", MatrixCoord_), + ("partitions", ctypes.c_int), + ("partition_stride", ctypes.c_longlong), + ("workspace", TensorRef2D_), + ("destination", TensorRef2D_), + ("source", TensorRef2D_), + ("output_op", _EpilogueOutputParams), + ] + + return _ReductionParams, _EpilogueOutputParams + + +########################################################################################### +# Epilogue Visitor Type Factory +########################################################################################### + +class Empty(ctypes.Structure): + _fields_ = [] + + def __init__(self, *arg) -> None: + pass + +class EmptyByte(ctypes.Structure): + _fields_ = [ + ("byte", ctypes.c_byte) + ] + + def __init__(self, *arg) -> None: + pass + +class EBO: + def __init__(self, index: int, type) -> None: + self.index = index + self.type = type + + def __eq__(self, other) -> bool: + if isinstance(other, EBO): + return self.index == other.index and self.type == other.type + return False + + def __hash__(self) -> int: + return 
hash((self.index, self.type)) + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self) -> str: + return f"<{self.index}, {self.type}>" + + +def tuple_factory_(input_tuple, dtype, constants=[0,1]): + """ + The factory function generating cute::Tuple with input tuple + :param input_tuple: the input tuple + :type input_tuple: tuple + :param dtype: the data type for non-constant values + :type dtype: str, "int32_t", "int", "int64_t" + :param constant: the values that will be treated as constants + :type constant: list[int] + + :return: ctype structure representing the cute::Tuple + :return: the empty base classes of the tuple + """ + + # The empty base classes of the current tuple + empty_bases = [] + # The first non empty base class + first_non_empty_base = None + # The ctype fields of the current tuple + ctype_fields = [] + + for idx, entry in enumerate(input_tuple): + # For nested tuples + if isinstance(entry, tuple): + sub_tuple_ctype, sub_empty_bases = tuple_factory_(entry, dtype, constants) + if ctypes.sizeof(sub_tuple_ctype) == 0: + # The empty tuple base class is also an empty EBO + empty_bases.append(EBO(idx, entry)) + else: + if first_non_empty_base is None: + first_non_empty_base = sub_empty_bases + ctype_fields.append((f"entry_{idx}", sub_tuple_ctype)) + else: + if entry in constants: + empty_bases.append(EBO(idx, entry)) + ctype_fields.append((f"entry_{idx}", Empty)) + else: + ctype_fields.append((f"entry_{idx}", dtype)) + if first_non_empty_base is None: + first_non_empty_base = [] + + # Create the ctype tuple + class TupleType(ctypes.Structure): + _fields_ = ctype_fields + + def __init__(self, args) -> None: + fields = self._fields_ + + assert len(fields) == len(args) + for field, arg in zip(fields, args): + name = field[0] + field_type = field[1] + setattr(self, name, field_type(arg)) + + return TupleType, empty_bases + +def tuple_factory(input_tuple, dtype: str, constants=[0,1]): + """ + The factory function generating 
cute::Tuple with input tuple + :param input_tuple: the input tuple + :type input_tuple: tuple + :param dtype: the data type for non-constant values + :type dtype: str, "int32_t", "int", "int64_t" + :param constant: the values that will be treated as constants + :type constant: list[int] + + :return: ctype structure representing the cute::Tuple + :return: the empty base classes of the tuple + """ + # Step 1: convert the dtype + if dtype == "int64_t": + dtype = ctypes.c_longlong + elif dtype in ["int", "int32_t"]: + dtype = ctypes.c_int32 + else: + raise NotImplementedError(f"Type {dtype} is not supported") + + tuple_type, _ = tuple_factory_(input_tuple, dtype, constants) + + if ctypes.sizeof(tuple_type) == 0: + return EmptyByte + return tuple_type + + +def visitor_factory(node_types, node_names): + """ + Creates the argument type of epilogue visitor type + + :param node_types: list of argument types under ctypes + :param node_names: list of argument names under str + + :return: tuple type in ctypes.Structure + """ + ctypes_field = [] + # Struct is used when number of nodes < 4 + # Because the Sm90VisitorImplBase has specification up to 4 nodes + # in `include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp` + if len(node_types) <= 4: + for idx, node_type in enumerate(node_types): + if ctypes.sizeof(node_type) == 0: + # Special case for empty struct + # 1 byte placeholder is used for correct alignment + ctypes_field.append((node_names[idx], ctypes.c_byte)) + else: + ctypes_field.append((node_names[idx], node_type)) + + class VisitorType(ctypes.Structure): + _fields_ = ctypes_field + + def __init__(self, kwargs) -> None: + for field in self._fields_: + fname, ftype = field + if ftype != ctypes.c_byte: + setattr(self, fname, ftype(kwargs)) + + # For cases with more than 4 nodes, tuple is used + else: + for idx, node_type in enumerate(node_types): + ctypes_field.append((node_names[idx], node_type)) + + class VisitorType(ctypes.Structure): + _fields_ = 
ctypes_field + + def __init__(self, kwargs) -> None: + for field in self._fields_: + fname, ftype = field + setattr(self, fname, ftype(kwargs)) + + return VisitorType diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/compiler.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/compiler.py new file mode 100644 index 0000000000000000000000000000000000000000..0b66ce8a2402a109e2da00613e7255760685855c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/compiler.py @@ -0,0 +1,462 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import ctypes +import json +import os +import sqlite3 +import subprocess +import tempfile + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +cudart = lazy_import("cuda.cudart") +nvrtc = lazy_import("cuda.nvrtc") +from cutlass_library import SubstituteTemplate + +import cutlass_cppgen +from cutlass_cppgen import CACHE_FILE, CUTLASS_PATH, cuda_install_path, logger +from cutlass_cppgen.backend.gemm_operation import GemmOperationUniversal +from cutlass_cppgen.backend.library import ApiVersion +from cutlass_cppgen.backend.utils.device import device_cc + +IncludeTemplate = r"""#include "${include}" +""" + + +def compile_with_nvcc(cmd, source, error_file): + succeed = True + try: + subprocess.check_output(cmd, stderr=subprocess.STDOUT) + except subprocess.CalledProcessError as e: + error_message = e.output.decode() + with open(error_file, "w") as error_out: + error_log = "Compilation error for the following kernel: \n" + error_log += source + error_log += "\nError Message:\n" + error_log += error_message + error_out.write(error_log) + succeed = False + if not succeed: + # Print the error log to stdout if log level is set to warning or higher + # verbosity. Otherwise, simply point to the error log file. + logger.warning(error_log) + raise Exception(f"Invalid Kernel. 
See '{error_file}' for details.") + + +class CompilationOptions: + """ + Compilation options. + """ + + def __init__(self, flags, arch, include_paths=[]): + self.includes = [] + self.include_paths = include_paths + self.flags = flags + self.arch = arch + + def get_str(self): + opts = [] + for flag in self.flags: + opts.append(flag) + + for incl in self.include_paths: + opts.append(f"--include-path={incl}") + + arch_flag = f"-arch=sm_{self.arch}" + if self.arch in [90, 100, 101, 103, 120, 121] and int(cutlass_cppgen.nvcc_version().split('.')[0]) >= 12: + arch_flag += "a" + opts.append(arch_flag) + + return " ".join(opts) + + def get(self): + options = [] + + for flag in self.flags: + options.append(bytes(str.encode(flag))) + + for incl in self.include_paths: + options.append(bytes(str.encode(f" --include-path={incl}"))) + + arch_flag = f" -arch=sm_{self.arch}" + if self.arch in [90, 100, 101, 103, 120, 121]: + arch_flag += "a" + + options.append(bytes(str.encode(arch_flag))) + + return options + + +def convertToBinaryData(filename): + with open(filename, "rb") as file: + blobData = file.read() + return blobData + + +def CDLLBin(host_binary): + tempfile.tempdir = "./" + temp_so = tempfile.NamedTemporaryFile(prefix="host_func", suffix=".so", delete=True) + with open(temp_so.name, "wb") as file: + file.write(host_binary) + host_lib = ctypes.CDLL(temp_so.name) + return host_lib + + +class ArtifactManager: + """ + Artifact manager + """ + + def __init__(self) -> None: + connection = sqlite3.connect(CACHE_FILE) + cursor = connection.cursor() + # Create the table if it does not already exist + sqlite_create_table_query = """ + CREATE TABLE IF NOT EXISTS compiled_operations(op_key TEXT NOT NULL UNIQUE, + cubin BLOB NOT NULL, + hostbin BLOB NOT NULL, + op_name TEXT NOT NULL, + op_attrs TEXT NOT NULL) + """ + cursor.execute(sqlite_create_table_query) + connection.commit() + cursor.close() + + self._nvrtc_compile_options = ["-std=c++17", "-default-device"] + 
self._nvcc_compile_options = [ + "-std=c++17", + "--expt-relaxed-constexpr", + "-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored", + ] + self.nvcc() + self.compiled_cache_device = {} + self.compiled_cache_host = {} + + def nvrtc(self): + self.backend = "nvrtc" + self.default_compile_options = self._nvrtc_compile_options + + def nvcc(self): + self.backend = "nvcc" + self.default_compile_options = self._nvcc_compile_options + + def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs): + connection = sqlite3.connect(CACHE_FILE) + cursor = connection.cursor() + sqlite_insert_blob_query = """ INSERT OR IGNORE INTO compiled_operations (op_key, cubin, hostbin, op_name, op_attrs) VALUES (?, ?, ?, ?, ?)""" + + hostbin = convertToBinaryData(hostfile) + + data_tuple = (op_key, cubin, hostbin, op_name, json.dumps(op_attrs)) + + cursor.execute(sqlite_insert_blob_query, data_tuple) + connection.commit() + cursor.close() + + def load_operation(self, op_key, extra_funcs): + connection = sqlite3.connect(CACHE_FILE) + cursor = connection.cursor() + sqlite_fetch_blob_query = """SELECT * from compiled_operations where op_key = ?""" + cursor.execute(sqlite_fetch_blob_query, (op_key,)) + record = cursor.fetchall() + if len(record) == 0: + return False + for row in record: + key, cubin_image, host_binary, operation_name, op_attr = row + op_attr = json.loads(op_attr) + err, module = cuda.cuModuleLoadData(cubin_image) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("Cuda Error: {}".format(err)) + + err, kernel = cuda.cuModuleGetFunction(module, bytes(str.encode(operation_name))) + self.compiled_cache_device[key] = kernel + + compiled_host_fns = {} + host_lib = CDLLBin(host_binary) + + func_name = operation_name + "_get_params" + func = getattr(host_lib, func_name) + func.restype = ctypes.POINTER(ctypes.c_char * op_attr[0]) + compiled_host_fns["get_args"] = func + + func_name = operation_name + "_shared_memory_size" + func = getattr(host_lib, func_name) 
+ compiled_host_fns["shared_memory_capacity"] = func() + + for attr in op_attr: + if isinstance(attr, str): + func_name = operation_name + "_" + attr + func = getattr(host_lib, func_name) + + # Set the return type of the function + if attr in extra_funcs and extra_funcs[attr] != None: + func.restype = extra_funcs[attr] + + compiled_host_fns[attr] = func + + self.compiled_cache_host[key] = compiled_host_fns + return True + + def emit_compile_(self, operation_list, compilation_options, host_compilation_options): + """ + Compile a list of kernels and store them into database + """ + source_buffer_device = "" + source_buffer_host = "" + # 1. include + includes = [] + for operation in operation_list: + for incl in operation.emitter.includes: + if incl not in includes: + includes.append(incl) + + includes_host = ["builtin_types.h", "device_launch_parameters.h", "cstddef"] + includes + for incl in includes: + source_buffer_device += SubstituteTemplate( + IncludeTemplate, + {"include": incl}, + ) + + for incl in includes_host: + source_buffer_host += SubstituteTemplate( + IncludeTemplate, + {"include": incl}, + ) + + # 2. Operations + for operation in operation_list: + source_buffer_device += operation.emit() + source_buffer_host += operation.emit() + values = { + "operation_name": operation.name(), + "operation_suffix": operation.emitter.operation_suffix, + } + source_buffer_device += SubstituteTemplate( + operation.KernelTemplate, + values, + ) + source_buffer_host += SubstituteTemplate(operation.HostTemplate, values) + + if self.backend == "nvrtc": + # 3. 
compile + err, program = nvrtc.nvrtcCreateProgram( + str.encode(source_buffer_device), + bytes(str.encode("module.cu")), + 0, [], []) + + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError("NVRTC Error: {}".format(err)) + + # Compile program + options = compilation_options.get() + + err, = nvrtc.nvrtcCompileProgram(program, len(options), options) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + error_string = "NVRTC Error: {}\n".format(err) + + # Get log from compilation + err, logSize = nvrtc.nvrtcGetProgramLogSize(program) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError("NVRTC Error: {}".format(err)) + + log = b" " * logSize + err, = nvrtc.nvrtcGetProgramLog(program, log) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError("NVRTC Error: {}".format(err)) + + raise RuntimeError(error_string + log.decode() + source_buffer_device) + + # Get data from compilation + err, dataSize = nvrtc.nvrtcGetCUBINSize(program) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError("NVRTC Error: {}".format(err)) + + cubin_image = b" " * dataSize + (err,) = nvrtc.nvrtcGetCUBIN(program, cubin_image) + if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: + raise RuntimeError("NVRTC Error: {}".format(err)) + + else: # with nvcc backend + # emit code + tempfile.tempdir = "./" + temp_cu = tempfile.NamedTemporaryFile( + prefix="kernel", suffix=".cu", delete=True) + temp_cubin = tempfile.NamedTemporaryFile( + prefix="kernel", suffix=".cubin", delete=True) + with open(temp_cu.name, "w") as file: + file.write(source_buffer_device) + + # compile with nvcc + cmd_template = "${cuda_install_path}/bin/nvcc ${options} -cubin ${srcfile} -o ${tarfile}" + values = { + "cuda_install_path": cuda_install_path(), + "options": compilation_options.get_str(), + "srcfile": temp_cu.name, + "tarfile": temp_cubin.name, + } + cmd = SubstituteTemplate(cmd_template, values) + compile_with_nvcc(cmd.split(" "), source_buffer_device, 
"./cutlass_python_compilation_device_error.txt") + + # load the cubin image + with open(temp_cubin.name, "rb") as file: + cubin_image = file.read() + + tempfile.tempdir = "./" + temp_src = tempfile.NamedTemporaryFile( + prefix="host_src", suffix=".cu", delete=True) + + # Write the host source + with open(temp_src.name, "w") as outfile: + outfile.write(source_buffer_host) + + temp_dst = tempfile.NamedTemporaryFile( + prefix="host_func", suffix=".so", delete=True) + + # Set up host compilation arguments + cmd = [] + cmd.append(f"{cuda_install_path()}/bin/nvcc") + cmd.extend(["-x", "cu", "-Xcompiler=-fpermissive", "-Xcompiler=-w", "-Xcompiler=-fPIC"]) + cmd.extend(host_compilation_options.get_str().split(" ")) + cmd.extend(["-shared", "-o", temp_dst.name, temp_src.name, "-lcudart", "-lcuda"]) + + # Comile and load the library + compile_with_nvcc( cmd, source_buffer_host, error_file="./cutlass_python_compilation_host_error.txt") + host_lib = ctypes.CDLL(temp_dst.name) + + return cubin_image, host_lib, temp_dst + + def add_module(self, operations, compile_options=None, bypass_cache=False): + """ + Insert a new compiled device module + """ + include_paths = [ + cuda_install_path() + "/include", + CUTLASS_PATH + "/include", + CUTLASS_PATH + "/tools/util/include", + CUTLASS_PATH + "/python/cutlass/cpp/include", + ] + + cutlass_cppgen.initialize_cuda_context() + arch = device_cc() + + host_compile_options = CompilationOptions( + self._nvcc_compile_options, arch, include_paths) + if compile_options is None: + compile_options = CompilationOptions( + self.default_compile_options, arch, include_paths) + # save the cubin + operation_key = [] + operation_list = [] + for operation in operations: + # step 1: get kernel string as key + key = operation.rt_module.emit() + operation.procedural_name() + self.backend + # step 1: check if the operation is in cache + compiled_kernel = self.compiled_cache_device.get(key) + + if compiled_kernel is None and not bypass_cache: + hit = 
self.load_operation(key, getattr( operation.rt_module, "extra_funcs", {})) + if hit: + compiled_kernel = self.compiled_cache_device.get(key) + assert compiled_kernel is not None + if compiled_kernel is not None: + operation.rt_module.kernel = compiled_kernel + compiled_host_fns = self.compiled_cache_host.get(key) + assert compiled_host_fns is not None + for key in compiled_host_fns.keys(): + setattr(operation.rt_module, key, compiled_host_fns[key]) + operation.rt_module.initialize() + else: + operation_list.append(operation.rt_module) + operation_key.append(key) + + if len(operation_list) > 0: + cubin_image, host_lib, host_file = self.emit_compile_( + operation_list, compile_options, host_compile_options) + + err, module = cuda.cuModuleLoadData(cubin_image) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("Cuda Error: {}".format(err)) + + operation_name = [] + operation_attr = [] + for operation, key in zip(operation_list, operation_key): + # get device kernels + err, operation.kernel = cuda.cuModuleGetFunction( + module, + bytes(str.encode(operation.name())) + ) + operation_name.append(operation.name()) + self.compiled_cache_device[key] = operation.kernel + # get host functions + compiled_host_fns = {} + op_attr = [] + + # get param size + func_name = operation.name() + "_get_param_size" + func = getattr(host_lib, func_name) + param_size = func() + + func_name = operation.name() + "_get_params" + func = getattr(host_lib, func_name) + func.argtype = operation.argtype + func.restype = ctypes.POINTER(ctypes.c_char * param_size) + setattr(operation, "get_args", func) + compiled_host_fns["get_args"] = func + + # set shared memory size + func_name = operation.name() + "_shared_memory_size" + func = getattr(host_lib, func_name) + setattr(operation, "shared_memory_capacity", func()) + compiled_host_fns["shared_memory_capacity"] = func() + # set the maximum dynamic shared size + operation.initialize() + + # get extra functions + op_attr.append(param_size) + + 
if hasattr(operation, "extra_funcs"): + for suffix, ret_type in operation.extra_funcs.items(): + func_name = operation.name() + "_" + suffix + func = getattr(host_lib, func_name) + if ret_type is not None: + func.restype = ret_type + setattr(operation, suffix, func) + compiled_host_fns[suffix] = func + op_attr.append(suffix) + + operation_attr.append(op_attr) + self.compiled_cache_host[key] = compiled_host_fns + + for (key, operation_name, operation_attr,) in zip(operation_key, operation_name, operation_attr): + self.insert_operation( + key, cubin_image, host_file.name, operation_name, operation_attr) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/conv2d_operation.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/conv2d_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..03679c434e1a63e9d1f9f2d1571dacedcf6e1470 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/conv2d_operation.py @@ -0,0 +1,700 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# +from __future__ import annotations + +import ctypes +from typing import Union + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +from cutlass_library import SubstituteTemplate +import numpy as np + +from cutlass_library import ( + ConvKindNames, + ConvKindTag, + DataTypeNames, + DataTypeSize, + DataTypeTag, + IteratorAlgorithmNames, + IteratorAlgorithmTag, + LayoutTag, + LayoutType, + MathOperation, + MathOperationTag, + OpcodeClass, + OpcodeClassNames, + OpcodeClassTag, + OperationKind, + ShortDataTypeNames, + ShortLayoutTypeNames, + SplitKMode, + StrideSupport, + StrideSupportTag, + SwizzlingFunctor, + SwizzlingFunctorTag, + get_complex_from_real, +) + +from cutlass_cppgen.backend.arguments import ArgumentBase +from cutlass_cppgen.backend.c_types import dim3_, get_conv2d_arguments +from cutlass_cppgen.backend.library import ( + EmissionType, + 
TensorDescription, + TileDescription, +) +from cutlass_cppgen.backend.memory_manager import device_mem_alloc +from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration +from cutlass_cppgen.backend.utils.device import to_device_ptr +from cutlass_cppgen.shape import GemmCoord + + +class Conv2dArguments(ArgumentBase): + """ + Argument wrapper for Conv2d. It encodes problem information and + user-provide tensors into the kernel's argument. + + :param operation: the Conv2d operation to take the argument + :type operation: :class:`cutlass_cppgen.backend.Conv2dOperation` + :param problem_size: the Conv2d problem size + :type problem_size: :class:`cutlass_cppgen.shape.Conv2dProblemSize` + :param A: tensor A + :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + :param B: tensor B + :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + :param C: tensor C + :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + :param D: tensor D + :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + :param split_k_mode: conv2d split K mode, defaults to cutlass_library.library.SplitKMode.Serial + :type split_k_mode: cutlass_library.library.SplitKMode, optional + :param output_op: output operator, optional + :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` + :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) + :type stream: :class:`cuda.cuda.CUstream` + """ + + def __init__(self, operation, problem_size, A, B, C, D, + split_k_mode=SplitKMode.Serial, **kwargs, ) -> None: + self.operation = operation + self.conv_kind = operation.conv_kind + self.layout_A = operation.A.layout + self.layout_B = operation.B.layout + self.layout_C = operation.C.layout + + self.element_A = operation.A.element + self.element_B = operation.B.element + self.element_C = operation.C.element + + if self.layout_C == LayoutType.TensorNC32HW32: + raise Exception("Layout 
type TensorNC32HW32 is not currently supported") + + super().__init__(A, B, C, D, **kwargs) + + if "split_k_slices" in kwargs.keys() and kwargs["split_k_slices"] > 1: + self.split_k_mode = split_k_mode + self.split_k_slices = kwargs["split_k_slices"] + else: + self.split_k_mode = SplitKMode.Serial + self.split_k_slices = 1 + + if "output_op" in kwargs.keys() and self.split_k_mode != SplitKMode.Parallel: + self.output_op = kwargs["output_op"] + else: + self.output_op = self.operation.epilogue_type(1.0, 0.0) + + self.problem_size = problem_size + self.problem_size.split_k_slices = self.split_k_slices + + self.initialize() + + def get_arguments(self): + tc_numel = -1 + if hasattr(self, "tensor_c_numel"): + tc_numel = self.tensor_c_numel + + self.c_arguments = self.operation.argument_type( + int(self.conv_kind), + self.problem_size.ctype, + int(to_device_ptr(self.ptr_A)), + int(to_device_ptr(self.ptr_B)), + int(to_device_ptr(self.ptr_C)), + int(to_device_ptr(self.ptr_D)), + tc_numel, + self.output_op, + int(self.split_k_mode) + ) + + def initialize(self): + self.launch_config = self.operation.rt_module.plan(self) + + self.get_arguments() + + # Allocate and initialize device workspace + device_workspace_size = self.operation.rt_module.get_workspace_size(self.c_arguments) + if device_workspace_size > 0: + self.workspace_buffer = device_mem_alloc(device_workspace_size) + workspace_ptr = self.workspace_buffer.ptr + err, = cuda.cuMemsetD32( + workspace_ptr, 0, device_workspace_size // 4) + else: + workspace_ptr = None + + self.semaphore = 0 + if workspace_ptr is not None and self.split_k_mode == SplitKMode.Parallel: + self.ptr_D = workspace_ptr + # Reset arguments now that ptr_D has been updated + self.get_arguments() + elif workspace_ptr is not None and self.split_k_mode == SplitKMode.Serial: + self.semaphore = workspace_ptr + + params_ = self.operation.rt_module.get_args( + self.c_arguments, ctypes.c_void_p(int(self.semaphore))) + self.host_workspace = 
bytearray(params_.contents) + self.device_workspace = None + + def sync(self): + """ + Synchronize the arguments. If the input tensor is in host, + copy it from device to host. + """ + return super().sync() + + +class Conv2dRT(ExecutableOperation): + """ + Conv2dRT manages the CUTLASS runtime components + """ + + KernelTemplate = r""" +extern "C" +__global__ void +${operation_name}(${operation_name}${operation_suffix}::Params params) { + + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + + // Declare pointer to dynamic shared memory. + ${operation_name}${operation_suffix}::SharedStorage *shared_storage = + reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase); + + ${operation_name}${operation_suffix} op; + + op(params, *shared_storage); +} + """ + + HostTemplate = r""" +extern "C" { + // Get the size of params in bytes + int ${operation_name}_get_param_size(){ + return sizeof(${operation_name}${operation_suffix}::Params); + } + + // Get the size of dynamic shared memory in bytes + int ${operation_name}_shared_memory_size() { + return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); + } + + using ElementA = typename ${operation_name}_base::ElementA; + using ElementB = typename ${operation_name}_base::ElementB; + using ElementC = typename ${operation_name}_base::ElementC; + using LayoutA = typename ${operation_name}_base::LayoutA; + using LayoutB = typename ${operation_name}_base::LayoutB; + using LayoutC = typename ${operation_name}_base::LayoutC; + using EpilogueOutputOp = typename ${operation_name}_base::EpilogueOutputOp; + + struct ${operation_name}_TemporaryArgs { + int conv_kind; + cutlass::conv::Conv2dProblemSize problem_size; + ElementA* ptr_A; + ElementB* ptr_B; + ElementC* ptr_C; + ElementC* ptr_D; + int tensor_c_numel; + typename EpilogueOutputOp::Params epilogue_params; + int split_k_mode; + }; + + typename ${operation_name}${operation_suffix}::Arguments + 
construct_arguments(${operation_name}_TemporaryArgs args) { + cutlass::conv::Operator conv_operator = static_cast(args.conv_kind); + auto tc_A = cutlass::conv::implicit_gemm_tensor_a_extent(conv_operator, args.problem_size); + auto tc_B = cutlass::conv::implicit_gemm_tensor_b_extent(conv_operator, args.problem_size); + auto tc_C = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size); + auto tc_D = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size); + + auto size_C = tc_C.at(0) * tc_C.at(1) * tc_C.at(2) * tc_C.at(3); + if (args.tensor_c_numel >= 0 && args.tensor_c_numel == tc_C.at(3) && args.tensor_c_numel < size_C) { + // C is interpreted as bias + tc_C = {0, 0, 0, 0}; + } + + cutlass::TensorRef tref_A(args.ptr_A, LayoutA::packed(tc_A)); + cutlass::TensorRef tref_B(args.ptr_B, LayoutB::packed(tc_B)); + cutlass::TensorRef tref_C(args.ptr_C, LayoutC::packed(tc_C)); + cutlass::TensorRef tref_D(args.ptr_D, LayoutC::packed(tc_D)); + + return { + args.problem_size, + tref_A, + tref_B, + tref_C, + tref_D, + args.epilogue_params, + static_cast(args.split_k_mode) + }; + } + + // Get the params as byte array + char* ${operation_name}_get_params(${operation_name}_TemporaryArgs args, int *semaphore=nullptr) { + auto arguments = construct_arguments(args); + typename ${operation_name}${operation_suffix}::Params* params; + params = new ${operation_name}${operation_suffix}::Params(arguments, semaphore); + + char *bytes = ((char*)(params)); + char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)]; + for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++) + output[i] = bytes[i]; + + return output; + } + + dim3 ${operation_name}_get_grid_shape( + int conv_kind, + cutlass::conv::Conv2dProblemSize problem_size, + cutlass::gemm::GemmCoord tile_size, + int split_k_slices + ) { + + using Swizzle = typename ${operation_name}_base::ThreadblockSwizzle; + auto tiled_shape = 
Swizzle::get_tiled_shape( + static_cast(conv_kind), + problem_size, + tile_size, + split_k_slices); + + return Swizzle::get_grid_shape(tiled_shape); + } + + size_t ${operation_name}_get_workspace_size(${operation_name}_TemporaryArgs args) { + auto arguments = construct_arguments(args); + + // Temporarily define device::-level Conv2d so that we can call get_workspace_size + using DeviceConv = cutlass::conv::device::ImplicitGemmConvolution<${operation_name}_base>; + return DeviceConv::get_workspace_size(arguments); + } +} + + """ + + def __init__(self, operation: "Conv2dOperation"): + super().__init__(operation) + self.extra_funcs = { + "get_grid_shape": dim3_, + "get_workspace_size": ctypes.c_uint64 + } + self.argument_type, self.epilogue_type = get_conv2d_arguments(operation.epilogue_functor) + self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_void_p] + self.conv_kind = operation.conv_kind + + self.operation: Conv2dOperation = operation + + self.emitter = EmitConv2dInstance("_type") + + self.threads = operation.tile_description.num_threads + + self.swizzle_functor = operation.swizzling_functor + + def emit(self): + return self.emitter.emit(self.operation) + + def plan(self, arguments: Conv2dArguments): + tile_size = GemmCoord( + self.operation.tile_description.threadblock_shape[0], + self.operation.tile_description.threadblock_shape[1], + self.operation.tile_description.threadblock_shape[2], + ) + + grid = self.get_grid_shape( + int(self.conv_kind), + arguments.problem_size.ctype, + tile_size.ctype, + arguments.split_k_slices + ) + + return LaunchConfiguration( + [grid.x, grid.y, grid.z], [self.threads, 1, 1], + self.shared_memory_capacity) + + def initialize(self): + err, = cuda.cuFuncSetAttribute( + self.kernel, + attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + value=self.shared_memory_capacity) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error: {err}") + + +class Conv2dOperation: + """ + 
CUTLASS Conv2d operation description. + + :param conv_kind: convolution operator + :type conv_kind: :class:`cutlass_library.library.ConvKind` + + :param iterator_algorithm: Selects among several implementation + variants trading off performance with simplicity + :type iterator_algorithm: :class:`cutlass_library.library.IteratorAlgorithm` + + :param arch: GPU compute capability (sm_xx) + :type arch: int + + :param tile_description: tile description + :type tile_description: :class:`cutlass_cppgen.backend.TileDescription` + + :param A: tensor A description + :type A: :class:`cutlass_cppgen.backend.TensorDescription` + + :param B: tensor B description + :type B: :class:`cutlass_cppgen.backend.TensorDescription` + + :param C: tensor C description + :type C: :class:`cutlass_cppgen.backend.TensorDescription` + + :param D: tensor D description + :type D: :class:`cutlass_cppgen.backend.TensorDescription` + + :param element_epilogue: element type for computation in epilogue \ + :type element_epilogue: cutlass_library.library.DataType + + :param stride_support: distinguish among partial specializations that \ + accelerate certain problems where convolution stride is unit \ + :type stride_support: :class:`cutlass_library.library.StrideSupport` + + :param epilogue_functor: convolution epilogue functor + :type epilogue_functor: :class:`EpilogueFunctor` + + :param swizzling_functor: threadblock swizzling functor + """ + def __init__( + self, + conv_kind, + iterator_algorithm, + arch: int, + tile_description: TileDescription, + A: TensorDescription, + B: TensorDescription, + C: TensorDescription, + stride_support, + epilogue_functor, + swizzling_functor=SwizzlingFunctor.Identity1, + emission_type=EmissionType.Kernel, + **kwargs + ): + self.operation_kind: OperationKind = OperationKind.Conv2d + self.arch: int = arch + self.tile_description: TileDescription = tile_description + self.conv_kind = conv_kind + self.A: TensorDescription = A + self.B: TensorDescription = B + self.C: 
TensorDescription = C + self.epilogue_functor = epilogue_functor + self.iterator_algorithm = iterator_algorithm + self.stride_support = stride_support + self.swizzling_functor = swizzling_functor + + self.emission_type = emission_type + + self.rt_module: Conv2dRT = Conv2dRT(self) + self.argument_type = self.rt_module.argument_type + self.epilogue_type = self.rt_module.epilogue_type + + def run(self, arguments: Conv2dArguments) -> cuda.CUresult: + """ + Launch the cuda kernel with input arguments + + :param arguments: conv2d arguments + :type arguments: :class:`cutlass_cppgen.backend.Conv2dArguments` + """ + + # launch the kernel + err = self.rt_module.run( + arguments.host_workspace, + arguments.device_workspace, + arguments.launch_config, + arguments.stream + ) + + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {err}") + + return err + + # + # Get function name + # + + def procedural_name(self): + """The full procedural name indicates architecture, extended name, tile size, and layout.""" + return self.configuration_name() + + def configuration_name(self): + """The full procedural name indicates architecture, extended name, tile size, and layout.""" + + opcode_class_name = OpcodeClassNames[ + self.tile_description.math_instruction.opcode_class + ] + + threadblock = "%dx%d_%dx%d" % ( + self.tile_description.threadblock_shape[0], + self.tile_description.threadblock_shape[1], + self.tile_description.threadblock_shape[2], + self.tile_description.stages, + ) + + if self.stride_support == StrideSupport.Unity: + configuration_name = "cutlass_sm${arch}_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_align${alignment}" + else: + configuration_name = "cutlass_sm${arch}_${opcode_class}_${extended_name}_${threadblock}_${layout}_align${alignment}" + + return SubstituteTemplate( + configuration_name, + { + "arch": str(self.arch), + "opcode_class": opcode_class_name, + "extended_name": self.extended_name(), + "threadblock": 
threadblock, + "layout": self.layout_name(), + "alignment": "%d" % self.A.alignment + }, + ) + + def extended_name(self): + """Append data types if they differ from compute type.""" + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + "element_a": DataTypeNames[self.A.element], + "element_c": DataTypeNames[self.C.element], + "core_name": self.core_name(), + }) + + return extended_name + + def layout_name(self): + return "%s" % (ShortLayoutTypeNames[self.A.layout]) + + def core_name(self): + """The basic operation kind is prefixed with a letter indicating the accumulation type.""" + + intermediate_type = "" + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp: + inst_shape = "%dx%dx%d" % tuple( + self.tile_description.math_instruction.instruction_shape) + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.accumulator_type(): + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + else: + inst_shape = "" + + return "%s%s%s%s_%s" % ( + ShortDataTypeNames[self.accumulator_type()], + inst_shape, + intermediate_type, + ConvKindNames[self.conv_kind], + IteratorAlgorithmNames[self.iterator_algorithm] + ) + + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + ] + return self.tile_description.math_instruction.math_operation in complex_operators + + def 
accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + def device_op(self): + """ + Returns a new Conv2dOperation object that is constructed with emission type + ``EmissionType.Device``. + + :return: operation ready for device-level code emission + :rtype: Conv2dOperation + """ + return Conv2dOperation( + self.conv_kind, self.iterator_algorithm, self.arch, self.tile_description, + self.A, self.B, self.C, self.stride_support, self.epilogue_functor, self.swizzling_functor, + emission_type=EmissionType.Device) + + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + + +class EmitConv2dInstance: + def __init__(self, operation_suffix=""): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/conv/kernel/default_conv2d_fprop.h", + "cutlass/conv/kernel/default_conv2d_dgrad.h", + "cutlass/conv/kernel/default_conv2d_wgrad.h", + "cutlass/conv/device/implicit_gemm_convolution.h" + ] + self.template = """ +// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" +using ${operation_name}_base = +typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operator}, + 
${iterator_algorithm}, + ${stride_support}, + ${align_a}, + ${align_b} +>::Kernel; + +struct ${operation_name}${operation_suffix}: + public ${operation_name}_base { }; + +""" + + self.template_device = """ +// Conv2d operation ${operation_name} + +using Conv2d${conv_kind_name}Kernel = typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operator}, + ${iterator_algorithm}, + ${stride_support}, + ${align_a}, + ${align_b} +>::Kernel; + +using DeviceKernel = + typename cutlass::conv::device::ImplicitGemmConvolution; +""" + + def emit(self, operation): + warp_shape = [int(operation.tile_description.threadblock_shape[idx] / + operation.tile_description.warp_count[idx]) for idx in range(3)] + + epilogue_vector_length = int(min( + operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + "operation_name": operation.procedural_name(), + "operation_suffix": self.operation_suffix, + "conv_kind": ConvKindTag[operation.conv_kind], + "conv_kind_name": ConvKindNames[operation.conv_kind].capitalize(), + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[operation.A.layout], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[operation.B.layout], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[operation.C.layout], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": 
OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]), + "instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]), + "instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]), + "epilogue_vector_length": str(epilogue_vector_length), + "epilogue_functor": operation.epilogue_functor.emit(), + "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor], + "stages": str(operation.tile_description.stages), + "iterator_algorithm": IteratorAlgorithmTag[operation.iterator_algorithm], + "iterator_algorithm_name": IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(), + "stride_support": StrideSupportTag[operation.stride_support], + "math_operator": "cutlass::arch::OpMultiplyAddComplex" if operation.is_complex() else MathOperationTag[operation.tile_description.math_instruction.math_operation], + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + } + + if operation.emission_type == EmissionType.Kernel: + conv2d_template = self.template + else: + conv2d_template = self.template_device + + return SubstituteTemplate(conv2d_template, values) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/epilogue.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/epilogue.py new file mode 100644 index 
0000000000000000000000000000000000000000..49ad79c9c8ecc9cad6067a3d9543b2625344848b --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/epilogue.py @@ -0,0 +1,541 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#
#################################################################################################

import ctypes

from cutlass_library import SubstituteTemplate
import numpy as np

from cutlass_library import DataType, DataTypeTag
from cutlass_cppgen.backend.c_types import MatrixCoord_, tuple_factory
from cutlass_cppgen.backend.frontend import NumpyFrontend
from cutlass_cppgen.backend.library import ActivationOp, ActivationOpTag
from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_available, is_torch_tensor

# ctypes storage used for each epilogue element type. Both 16-bit float
# formats are carried as raw 16-bit integers.
dtype2ctype = {
    DataType.f16: ctypes.c_uint16,
    DataType.bf16: ctypes.c_uint16,
    DataType.f32: ctypes.c_float,
    DataType.f64: ctypes.c_double,
    DataType.s8: ctypes.c_int8,
    DataType.s32: ctypes.c_int32
}

if is_torch_available():
    import torch
    import torch.nn.functional as F


def get_scalar(value):
    """
    Returns a scalar value from a container (e.g., np.ndarray)

    :param value: a NumPy array, a torch tensor, or an already-scalar value
    :return: the single contained element, or ``value`` itself when it is
        neither a NumPy array nor a torch tensor
    :raises Exception: if a tensor argument does not hold exactly one element
    """
    if is_numpy_tensor(value):
        if value.size != 1:
            raise Exception("Scalars used in epilogue must be of size 1")
        return value.reshape(-1)[0]
    if is_torch_tensor(value):
        # ``torch.Tensor.size`` is a method, so comparing it with 1 was always
        # unequal and every torch scalar was rejected; numel() is the count.
        if value.numel() != 1:
            raise Exception("Scalars used in epilogue must be of size 1")
        return value.reshape(-1)[0].item()
    return value


def to_ctype_value(value, dtype):
    """
    Converts ``value`` to the corresponding storage needed for the ctype that
    will store ``value``.

    :param value: scalar or single-element tensor
    :param dtype: element type of the destination field
    :type dtype: cutlass_library.DataType
    :return: the raw 16-bit pattern as an int for ``DataType.f16``; the
        unwrapped scalar for every other type
    """
    scalar = get_scalar(value)
    if dtype == DataType.f16:
        # Convert f16 value into an integer
        return int.from_bytes(np.float16(scalar).tobytes(), "little")
    # NOTE(review): DataType.bf16 is also stored as c_uint16 in dtype2ctype but
    # gets no bit-pattern conversion here -- confirm whether bf16 epilogue
    # scalars are expected to work.
    return scalar


#################################################################################################
#
# Epilogue Functors
#
#################################################################################################


class EpilogueFunctorBase:
    """
    Base class for thread-level epilogue functors
    """

    def __init__(self) -> None:
        pass

    def emit(self, tag, template_argument):
        """
        Render ``tag<arg0, arg1, ...>``.

        :param tag: fully qualified C++ template name
        :type tag: str
        :param template_argument: template arguments, already rendered as strings
        :type template_argument: list[str]
        :return: the instantiated C++ type expression
        :rtype: str
        """
        template = """${tag}<${arguments}>"""
        values = {
            "tag": tag,
            "arguments": ", ".join(template_argument),
        }

        return SubstituteTemplate(template, values)


class LinearCombination(EpilogueFunctorBase):
    """
    Apply a linear combination operator to an array of elements
    D = alpha * accumulator + beta * source

    :param element_output: data type used to load and store tensors

    :param epilogue_vector_length: number of elements computed per operation.
        Usually it is 128 bits divided by the output element width, but we
        use 64 and 32 sometimes when there are not enough data to store

    :param element_accumulator: Accumulator data type; defaults to ``element_output``

    :param element_epilogue: data type used to compute linear combination;
        defaults to ``element_output``
    """

    tag = "cutlass::epilogue::thread::LinearCombination"

    def __init__(
        self, element_output, epilogue_vector_length,
        element_accumulator=None, element_epilogue=None) -> None:
        super().__init__()

        if element_accumulator is None:
            element_accumulator = element_output
        if element_epilogue is None:
            element_epilogue = element_output

        self.element_output = element_output
        self.element_accumulator = element_accumulator
        self.element_epilogue = element_epilogue
        self.epilogue_vector_length = epilogue_vector_length

        self.template_arguments = [
            DataTypeTag[element_output],
            str(epilogue_vector_length),
            DataTypeTag[element_accumulator],
            DataTypeTag[element_epilogue],
        ]

        c_element_epilogue = dtype2ctype[self.element_epilogue]
        element_epilogue = self.element_epilogue

        class _EpilogueOutputOpParamsEVT(ctypes.Structure):
            """
            Epilogue params when using the default linear combination of EVT, which
            does not currently use {alpha,beta}_ptr_array
            """

            stride_type = tuple_factory((0,0,1), "int64_t", [0])
            _fields_ = [
                ("alpha", c_element_epilogue),
                ("beta", c_element_epilogue),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
                ("dalpha", stride_type),
                ("dbeta", stride_type),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                self.alpha = to_ctype_value(alpha, element_epilogue)
                self.beta = to_ctype_value(beta, element_epilogue)

        class _EpilogueOutputOpParams(ctypes.Structure):
            _fields_ = [
                ("alpha", c_element_epilogue),
                ("beta", c_element_epilogue),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
                ("alpha_ptr_array", ctypes.c_void_p),
                ("beta_ptr_array", ctypes.c_void_p),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                self.alpha = to_ctype_value(alpha, element_epilogue)
                self.beta = to_ctype_value(beta, element_epilogue)

            def to_evt_params(self) -> _EpilogueOutputOpParamsEVT:
                return _EpilogueOutputOpParamsEVT(self.alpha, self.beta)

        self.epilogue_type = _EpilogueOutputOpParams
        self.epilogue_type_evt = _EpilogueOutputOpParamsEVT

    def emit(self):
        """Return the C++ type expression for this functor."""
        return super().emit(self.tag, self.template_arguments)


class LinearCombinationClamp(LinearCombination):
    """
    Applies a linear combination operator to an array of elements then clamps
    the output before converting to the output element type.

    D = alpha * accumulator + beta * source + uniform

    :param element_output: data type used to load and store tensors

    :param epilogue_vector_length: number of elements computed per operation.
        Usually it is 128 bits divided by the output element width, but we
        use 64 and 32 sometimes when there are not enough data to store

    :param element_accumulator: Accumulator data type

    :param element_epilogue: data type used to compute linear combination
    """

    tag = "cutlass::epilogue::thread::LinearCombinationClamp"

    def __init__(
        self, element_output, epilogue_vector_length,
        element_accumulator=None, element_epilogue=None) -> None:
        # Base constructor
        super().__init__(
            element_output,
            epilogue_vector_length,
            element_accumulator,
            element_epilogue,
        )

        c_element_epilogue = dtype2ctype[self.element_epilogue]
        element_epilogue = self.element_epilogue

        # Replaces the base class's parameter struct with one that has no
        # {alpha,beta}_ptr_array fields.
        class _EpilogueOutputOpParams(ctypes.Structure):
            _fields_ = [
                ("alpha", c_element_epilogue),
                ("beta", c_element_epilogue),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                self.alpha = to_ctype_value(alpha, element_epilogue)
                self.beta = to_ctype_value(beta, element_epilogue)

        self.epilogue_type = _EpilogueOutputOpParams


class FastLinearCombinationClamp(EpilogueFunctorBase):
    """
    Applies a linear combination operator to an array of elements then clamps
    the output before converting to the output element type.

    D = alpha * accumulator + beta * source

    Note: The below method only works when problem_size_K <= 256 for signed
    int8 gemm or problem_size_K <= 128 for unsigned int8 gemm. The default
    approach is above.

    :param element_output: data type used to load and store tensors

    :param epilogue_vector_length: number of elements computed per operation.
        Usually it is 128 bits divided by the output element width, but we
        use 64 and 32 sometimes when there are not enough data to store
    """

    tag = "cutlass::epilogue::thread::FastLinearCombinationClamp"

    def __init__(self, element_output, epilogue_vector_length, *args) -> None:
        super().__init__()

        self.template_arguments = [
            DataTypeTag[element_output], str(epilogue_vector_length)
        ]

        # Accumulator and epilogue types are fixed for this functor.
        self.element_accumulator = DataType.s32
        self.element_epilogue = DataType.f32

        # get epilogue output op
        c_element_epilogue = dtype2ctype[self.element_epilogue]
        element_epilogue = self.element_epilogue

        class _EpilogueOutputOpParams(ctypes.Structure):
            _fields_ = [
                ("alpha", c_element_epilogue),
                ("beta", c_element_epilogue),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                self.alpha = to_ctype_value(alpha, element_epilogue)
                self.beta = to_ctype_value(beta, element_epilogue)

        self.epilogue_type = _EpilogueOutputOpParams

    def emit(self):
        """Return the C++ type expression for this functor."""
        return super().emit(self.tag, self.template_arguments)


class LinearCombinationGeneric(LinearCombination):
    """
    Applies a linear combination operator followed by an activation function
    to an array of elements.

    D = activation(alpha * accumulator + beta * source)

    :param activation_functor: input activation functor

    :param element_output: data type used to load and store tensors

    :param epilogue_vector_length: number of elements computed per operation.
        Usually it is 128 bits divided by the output element width, but we
        use 64 and 32 sometimes when there are not enough data to store

    :param element_accumulator: Accumulator data type

    :param element_epilogue: data type used to compute linear combination
    """

    tag = "cutlass::epilogue::thread::LinearCombinationGeneric"

    def __init__(
        self, activation_functor,
        element_output, epilogue_vector_length,
        element_accumulator=None, element_epilogue=None) -> None:
        super().__init__(
            element_output,
            epilogue_vector_length,
            element_accumulator,
            element_epilogue,
        )

        self.template_arguments = [
            activation_functor.emit()] + self.template_arguments

        self.activation_functor = activation_functor
        # self.element_epilogue was already set by the base constructor, with
        # None defaulted to element_output. Re-assigning the raw argument here
        # would restore None and make the dtype2ctype lookup below fail.

        # get epilogue output op
        self.epilogue_type = self.activation_functor.epilogue_output_op(self.element_epilogue)


class ActivationFunctor:
    """
    Base class for frequently used activation functions
    """

    @staticmethod
    def numpy(x: np.ndarray):
        raise NotImplementedError()

    @classmethod
    def emit(cls):
        """Return the C++ tag of the activation selected by ``cls.binding_type``."""
        return ActivationOpTag[cls.binding_type]

    @staticmethod
    def epilogue_output_op(element_epilogue):
        """
        Build the ctypes parameter struct for an activation that takes no
        arguments beyond alpha and beta.

        :param element_epilogue: data type used to compute the linear combination
        :return: a ``ctypes.Structure`` subclass constructed as ``(alpha, beta)``
        """
        c_element_epilogue = dtype2ctype[element_epilogue]

        class _EpilogueOutputOpParams(ctypes.Structure):
            _fields_ = [
                ("alpha", c_element_epilogue),
                ("beta", c_element_epilogue),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
            ]

            def __init__(self, alpha, beta, *args) -> None:
                self.alpha = to_ctype_value(alpha, element_epilogue)
                self.beta = to_ctype_value(beta, element_epilogue)

        return _EpilogueOutputOpParams

class ActivationMeta(type):
    """
    Metaclass that makes an activation class callable as a host reference:
    calling the class dispatches to ``numpy`` or ``torch`` by input type.
    """

    @classmethod
    def __call__(cls, x, *args):
        if is_numpy_tensor(x):
            return cls.numpy(x, *args)
        elif is_torch_tensor(x):
            return cls.torch(x, *args)
        else:
            raise NotImplementedError("Unsupported tensor type")

    @classmethod
    def numpy(cls, *args):
        # cls is the metaclass here; [:-4] drops the trailing "Meta".
        raise NotImplementedError(f"Numpy reference for {cls.__name__[:-4]} is not implemented.")

    @classmethod
    def torch(cls, *args):
        raise NotImplementedError(f"PyTorch reference for {cls.__name__[:-4]} is not implemented.")

##############################################################################
# identity operator
class identityMeta(ActivationMeta):
    @classmethod
    def numpy(cls, x):
        return x

    @classmethod
    def torch(cls, x):
        return x

class identity(ActivationFunctor, metaclass=identityMeta):
    binding_type = ActivationOp.Identity


##############################################################################
# ReLu operator
class reluMeta(ActivationMeta):
    @classmethod
    def numpy(cls, x):
        return np.where(x > 0, x, 0)

    @classmethod
    def torch(cls, x):
        return F.relu(x)

class relu(ActivationFunctor, metaclass=reluMeta):
    binding_type = ActivationOp.ReLU


##############################################################################
# Leaky ReLu operator
class leakyReLUMeta(ActivationMeta):
    @classmethod
    def numpy(cls, x, leaky_alpha):
        return np.maximum(x, 0) + np.minimum(x, 0) * leaky_alpha

    @classmethod
    def torch(cls, x, leaky_alpha):
        return F.leaky_relu(x, leaky_alpha)

class leaky_relu(ActivationFunctor, metaclass=leakyReLUMeta):
    binding_type = ActivationOp.LeakyReLU

    @staticmethod
    def epilogue_output_op(element_epilogue):
        """
        Build the ctypes parameter struct for leaky ReLU, which carries an
        extra ``leaky_alpha`` field after the alpha/beta fields.

        :param element_epilogue: data type used to compute the linear combination
        :return: a ``ctypes.Structure`` subclass constructed as
            ``(alpha, beta, leaky_alpha=0.2)``
        """
        c_element_epilogue = dtype2ctype[element_epilogue]

        class _EpilogueOutputOpParams(ctypes.Structure):
            _fields_ = [
                ("alpha", c_element_epilogue),
                ("beta", c_element_epilogue),
                ("alpha_ptr", ctypes.c_void_p),
                ("beta_ptr", ctypes.c_void_p),
                ("leaky_alpha", c_element_epilogue)
            ]

            def __init__(self, alpha, beta, leaky_alpha=0.2, *args) -> None:
                self.alpha = to_ctype_value(alpha, element_epilogue)
                self.beta = to_ctype_value(beta, element_epilogue)
                self.alpha_ptr = 0
                self.beta_ptr = 0
                self.leaky_alpha = to_ctype_value(leaky_alpha, element_epilogue)

        return _EpilogueOutputOpParams
##############################################################################
# Tanh operator
class tanhMeta(ActivationMeta):
    @classmethod
    def numpy(cls, x):
        return np.tanh(x)

    @classmethod
    def torch(cls, x):
        return torch.tanh(x)


class tanh(ActivationFunctor, metaclass=tanhMeta):
    binding_type = ActivationOp.Tanh


##############################################################################
# Sigmoid operator
class sigmoidMeta(ActivationMeta):
    @classmethod
    def numpy(cls, x):
        return 1.0 / (1.0 + np.exp(-x))

    @classmethod
    def torch(cls, x):
        return F.sigmoid(x)


class sigmoid(ActivationFunctor, metaclass=sigmoidMeta):
    binding_type = ActivationOp.Sigmoid


##############################################################################
# SiLu operator
class siluMeta(ActivationMeta):
    @classmethod
    def numpy(cls, x):
        # SiLU(x) = x * sigmoid(x); the sigmoid reference needs the input too.
        return x * sigmoidMeta.numpy(x)

    @classmethod
    def torch(cls, x):
        # Must be named ``torch``: that is the name ActivationMeta.__call__
        # dispatches to for torch tensors.
        return F.silu(x)

    # Backward-compatible alias for the name this method previously had.
    silu = torch


class silu(ActivationFunctor, metaclass=siluMeta):
    binding_type = ActivationOp.SiLU


##############################################################################
# Hardswish operator
class hardswishMeta(ActivationMeta):
    @classmethod
    def numpy(cls, x):
        # x * ReLU6(x + 3) / 6
        relu6 = np.minimum(np.maximum(x + 3.0, 0), 6.0)
        return x * relu6 / 6.0

    @classmethod
    def torch(cls, x):
        return F.hardswish(x)


class hardswish(ActivationFunctor, metaclass=hardswishMeta):
    binding_type = ActivationOp.HardSwish


##############################################################################
# GELU operator
class geluMeta(ActivationMeta):
    @classmethod
    def numpy(cls, x):
        # Imported lazily so scipy is only required when this reference runs.
        from scipy.special import erf
        return 0.5 * x * (1 + erf(x / np.sqrt(2.0)))

    @classmethod
    def torch(cls, x):
        return F.gelu(x)


class gelu(ActivationFunctor, metaclass=geluMeta):
    binding_type = ActivationOp.Gelu
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b61e983ab23bb5662d15e185184efa227351446d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/__init__.py @@ -0,0 +1,34 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_cppgen.backend.evt.epilogue import EpilogueFunctorVisitor +from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..945dcf80e307eb870f31722822f959da03e6c421 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/__init__.py @@ -0,0 +1,38 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +from cutlass_cppgen.backend.evt.backend.sm80_emitter import Sm80Emitter +import cutlass_cppgen.backend.evt.backend.sm80_nodes as sm80_nodes +from cutlass_cppgen.backend.evt.backend.sm90_emitter import Sm90Emitter +import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes +from cutlass_cppgen.backend.evt.backend.sm100_emitter import Sm100Emitter +import cutlass_cppgen.backend.evt.backend.sm100_nodes as sm100_nodes diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/emitter_base.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/emitter_base.py new file mode 100644 index 0000000000000000000000000000000000000000..72a7d8c04db5c8df2595fab8befaa07bf238c2f2 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/emitter_base.py @@ -0,0 +1,159 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
class FusionCallbacks:
    """Turns a DAG IR of epilogue visitor nodes into C++ type declarations."""

    def __init__(self, dag_ir: DAGIR, cc: int, emit_CD=True) -> None:
        """
        Emit the EVT fusion callbacks
        :param dag_ir: the DAG IR holding the epilogue visitor
        :param cc: compute capability
        :param emit_CD: whether to emit nodes C & D as a part of the fusion callbacks
            For Sm90, set emit_CD=False, as Tensor C & D are hardcoded in the collective API
            so that their shared memory can be explicitly reused
            For Sm89, set emit_CD=True as they are treated as normal AuxLoad & AuxStore nodes.
        """
        self.dag_ir = dag_ir
        self.emit_CD = emit_CD
        self.cc = cc
        # Everything from 90 upward is emitted with the Sm90 visitor types.
        self.evt_cc = 90 if cc >= 90 else cc
        self.namespace = "threadblock" if self.cc < 90 else "fusion"

    #
    # Helper functions
    #

    def get_visitor_name(self, node: str):
        """
        Get the visitor name
        """
        meta = self.dag_ir.get_node_meta(node)
        # Topological-visitor nodes and nodes without inputs keep their plain
        # name; every other node is referred to through its EVT wrapper.
        if isinstance(meta, TopoVisitorNode) or not self.dag_ir.in_degree(node) > 0:
            return meta.name_camel
        return f"EVT{meta.name_camel}"

    def emit(self):
        """
        Emit all declarations.

        :return: tuple ``(declarations, callback_name)``
        """
        ordered_metas = self.dag_ir.node_metas_topological_order()
        pieces = []
        # Step 1: emit individual node type decl
        # emit the EVT & DAG connector
        for meta in ordered_metas:
            if meta.disabled:
                continue
            pieces.append(self.emit_node(meta))
            if not self.emit_CD and meta.name == "D":
                # D gets its node declaration but no connector.
                continue
            if isinstance(meta, TopoVisitorNode):
                pieces.append(self.emit_dag(meta))
            else:
                pieces.append(self.emit_evt(meta))

        # Step 2: post-processing & get callback name
        if self.emit_CD:
            # The callback is the last node in the topological order
            callback_name = self.get_visitor_name(ordered_metas[-1].name)
        else:
            if not self.dag_ir.has_node("C"):
                pieces.append("using ElementC = void;\nusing StrideC = StrideD;\n")
            # The callback is the src of node D
            source_of_d = self.dag_ir.get_all_inputs("D")[0]
            callback_name = self.get_visitor_name(source_of_d)
        return "".join(pieces), callback_name

    def emit_evt(self, node):
        """Emit the EVT wrapper for ``node``; empty string if it has no inputs."""
        if self.dag_ir.in_degree(node.name) == 0:
            return ""

        header = f"""
using EVT{node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}EVT<
    {node.name_camel},
"""
        child_lines = []
        for child_name in self.dag_ir.get_all_inputs(node.name):
            child_lines.append(f" {self.get_visitor_name(child_name)}")

        return header + ",\n".join(child_lines) + ">;\n"

    def emit_dag(self, node):
        """Emit the topological-visitor declaration for a subgraph node."""
        subgraph = node.subgraph
        order = subgraph.nodes_topological_order()
        # The last node in the order is left out of both lists below.
        members = order[:-1]

        # Emit the Edge Tuple
        seq_lines = []
        for member in members:
            in_edges = subgraph.in_edges(member)
            weights = [subgraph.get_edge_weight(edge[0], edge[1]) for edge in in_edges]
            # Order each node's inputs by edge weight.
            ranked_sources = [edge[0] for _, edge in sorted(zip(weights, in_edges))]
            positions = [str(order.index(source)) for source in ranked_sources]
            seq_lines.append(" cute::seq<" + ", ".join(positions) + ">,\n")
        edge_tuples = "cute::tuple<\n" + "".join(seq_lines) + " >"

        # Emit the node list
        node_lines = []
        for member in members:
            member_meta = subgraph.get_node_meta(member)
            if member_meta.disabled:
                node_lines.append(f" {self.get_visitor_name(member)}")
            else:
                node_lines.append(f" {member_meta.name_camel}")
        dag_nodes = ",\n".join(node_lines)

        return f"""
using {node.name_camel} = cutlass::epilogue::{self.namespace}::Sm{self.evt_cc}TopologicalVisitor<
    {DataTypeTag[node.subgraph.element_compute]},
    {edge_tuples},
{dag_nodes}
>;
"""

    def emit_node(self, node):
        """Return the type declaration of ``node``, recursing into subgraphs."""
        if not isinstance(node, TopoVisitorNode):
            return node.underlying_impl.type_decl

        return "".join(
            self.emit_node(inner)
            for inner in node.subgraph.node_metas_topological_order()
            if not inner.disabled
        )
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
class Sm100CollectiveEpilogue:
    """Collects what is needed to emit the Sm100 epilogue descriptor."""

    def __init__(self, tile_description,
                 kernel_schedule,
                 epilogue_schedule,
                 element_accumulator,
                 element_d,
                 fusion_callbacks) -> None:

        # Only the first element of the returned pair is kept.
        shape_and_rest = to_blackwell_threadblock_shape(
            tile_description, tile_description.cluster_shape, kernel_schedule)
        self.cta_tile_mnk, _ = shape_and_rest
        self.element_accumulator = element_accumulator
        dag_ir = fusion_callbacks.dag_ir
        # Without a C node in the graph the C element type is void.
        self.element_c = (
            dag_ir.get_node_meta("C").element
            if dag_ir.has_node("C")
            else DataType.void
        )
        self.element_d = element_d
        self.schedule = epilogue_schedule
        self.fusion_callbacks = fusion_callbacks
        self.opclass = tile_description.math_instruction.opcode_class

    @property
    def CtaTileMNK(self) -> str:
        """
        The threadblock shape
        """
        extents = ", ".join(f"_{self.cta_tile_mnk[axis]}" for axis in range(3))
        return f"cute::Shape<{extents}>"

    @property
    def EpilogueTileType(self) -> str:
        """
        The epilogue tile type
        """
        return "cutlass::epilogue::collective::EpilogueTileAuto"

    @property
    def Schedule(self) -> str:
        """The tag of the epilogue schedule."""
        return EpilogueScheduleTag[self.schedule]

    def emit(self):
        """
        Emit the epilogue descriptor followed by the fusion callbacks.

        :return: tuple ``(callback_name, declarations)``
        """
        # NOTE(review): this emitter is constructed but never used below.
        tuple_emitter = TupleEmitter("int64_t")
        dag_ir = self.fusion_callbacks.dag_ir
        stride_D_str = dag_ir.get_node_meta("D").underlying_impl.stride_mnl
        # C reuses D's stride unless the graph has its own C node.
        if dag_ir.has_node("C"):
            stride_C_str = dag_ir.get_node_meta("C").underlying_impl.stride_mnl
        else:
            stride_C_str = stride_D_str

        callback_decl, callback_name = self.fusion_callbacks.emit()
        declarations = f"""
using EpilogueDescriptor = cutlass::epilogue::collective::detail::Sm100EpilogueDescriptor<
    {OpcodeClassTag[self.opclass]},
    {self.CtaTileMNK}, {self.EpilogueTileType},
    {DataTypeTag[self.element_accumulator]}, {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
    {self.Schedule}, {stride_C_str}, {stride_D_str},
    false /* IsPerColScaleSupported */,
    false /* IsBlockScaleSupported */
>;
{callback_decl}
"""
        return callback_name, declarations


class Sm100Emitter:
    """Entry point that wires an operation and an EVT graph to the Sm100 epilogue."""

    def __init__(self, operation: GemmOperationUniversal, graph) -> None:
        # C and D are not emitted as part of the callbacks here.
        callbacks = FusionCallbacks(graph, cc=100, emit_CD=False)
        tile = operation.tile_description

        self.collective_epilogue = Sm100CollectiveEpilogue(
            tile_description=tile,
            kernel_schedule=operation.tile_description.kernel_schedule,
            epilogue_schedule=operation.tile_description.epilogue_schedule,
            element_accumulator=operation.tile_description.math_instruction.element_accumulator,
            element_d=callbacks.dag_ir.get_node_meta("D").element,
            fusion_callbacks=callbacks
        )

    def emit(self):
        """Delegate to the collective epilogue."""
        return self.collective_epilogue.emit()
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
# The remaining Sm100 node implementations are the Sm90 ones under a new name.
Sm100AccumulatorImpl = sm90_nodes.Sm90AccumulatorImpl
Sm100LoadSrcImpl = sm90_nodes.Sm90LoadSrcImpl
Sm100ScalarBroadcastImpl = sm90_nodes.Sm90ScalarBroadcastImpl
Sm100RowBroadcastImpl = sm90_nodes.Sm90RowBroadcastImpl
Sm100ColumnBroadcastImpl = sm90_nodes.Sm90ColumnBroadcastImpl
Sm100ComputeImpl = sm90_nodes.Sm90ComputeImpl
Sm100StoreDImpl = sm90_nodes.Sm90StoreDImpl
Sm100ColumnReductionImpl = sm90_nodes.Sm90ColumnReductionImpl
Sm100RowReductionImpl = sm90_nodes.Sm90RowReductionImpl
Sm100ScalarReductionImpl = sm90_nodes.Sm90ScalarReductionImpl


class Sm100AuxLoadImpl(AuxLoadImpl):

    @property
    def descriptor(self) -> str:
        """
        Descriptor for Aux Load
        """
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """
        Declare the descriptor type
        """
        # The descriptor is instantiated with the same arguments the aux-store
        # descriptor in this module receives; the declaration emitted by
        # ``type_decl`` reads members such as ``::Stages`` from it.
        # NOTE(review): argument list mirrors Sm100AuxStoreDescriptor -- confirm
        # against the CUTLASS header.
        return (
            f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxLoadDescriptor<"
            f"EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"
        )

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        # Built once, then served from the cached value.
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = self.decl_descriptor()
        self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
    {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
    {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
>;
"""
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """
        Get the shared memory size based on epilogue_tile_mn and stages_c

        :return: tuple whose first entry is
            ``DataTypeSize[element] * stages_c * product(epilogue_tile_mn) // 8``
            and whose second entry is the constant 128
            (presumably size in bytes and alignment -- TODO confirm)
        """
        return (DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8, 128)


class Sm100AuxStoreImpl(AuxStoreImpl):

    @property
    def descriptor(self) -> str:
        """
        Descriptor for Aux Store
        """
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """
        Declare the descriptor type
        """
        return f"""
using {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxStoreDescriptor<
    EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
>;
"""

    @property
    def type_decl(self):
        """
        Return the string defining the type
        """
        # Built once, then served from the cached value.
        if self._type_decl is not None:
            return self._type_decl

        self._type_decl = self.decl_descriptor()
        self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
    {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
    {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
    typename {self.descriptor}::CopyOpR2S
>;
"""
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """
        Get the shared memory size based on epilogue_tile_mn and stages_d

        :return: tuple whose first entry is
            ``DataTypeSize[element] * stages_d * product(epilogue_tile_mn) // 8``
            and whose second entry is the constant 128
            (presumably size in bytes and alignment -- TODO confirm)
        """
        return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128)
+# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
class Sm80Emitter:
    """Emitter for the Sm80 epilogue visitor: a thin wrapper over FusionCallbacks."""

    def __init__(self, operation: GemmOperationUniversal, graph) -> None:
        # ``operation`` is accepted but not read here; only the graph is used.
        self.fusion_callbacks = FusionCallbacks(graph, cc=80)

    def emit(self):
        """
        Emit the fusion callbacks.

        :return: tuple ``(callback_name, callback_declarations)`` -- the
            reverse of the order FusionCallbacks.emit() returns them in
        """
        declarations, name = self.fusion_callbacks.emit()
        return name, declarations
class Sm80AccumulatorImpl(AccumulatorImpl):

    @property
    def type_decl(self):
        """
        C++ alias declaring this node as an accumulator fetch visitor
        (built on first access, cached afterwards)
        """
        if self._type_decl is None:
            self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::threadblock::VisitorAccFetch;\n"""
        return self._type_decl


class Sm80AuxLoadImpl(AuxLoadImpl):

    @property
    def type_decl(self):
        """
        C++ alias declaring this node as an auxiliary load visitor
        (built on first access, cached afterwards)
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxLoad<
    OutputTileThreadMap, {DataTypeTag[self.element]}, {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm80LoadSrcImpl(Sm80AuxLoadImpl):
    # Inherits the aux-load declaration unchanged.
    pass


class Sm80ScalarBroadcastImpl(ScalarBroadcastImpl):
    def __init__(self, node: LoadNode) -> None:
        super().__init__(node)
        # Fixed values forwarded as template arguments in ``type_decl``.
        self.broadcast_count = 1
        self.reduction_fn = FunctionalOp.Multiplies

    @property
    def type_decl(self):
        """
        C++ alias declaring this node as a scalar broadcast visitor
        (built on first access, cached afterwards)
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarBroadcast<
    {DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
>;
"""
        return self._type_decl


class Sm80RowBroadcastImpl(RowBroadcastImpl):

    @property
    def type_decl(self):
        """
        C++ alias declaring this node as a row broadcast visitor
        (built on first access, cached afterwards)
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowBroadcast<
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {self.stride_mnl}
>;
"""
        return self._type_decl


class Sm80ColumnBroadcastImpl(ColumnBroadcastImpl):

    @property
    def type_decl(self):
        """
        C++ alias declaring this node as a column broadcast visitor
        (built on first access, cached afterwards)
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, {DataTypeTag[self.element]},
    {self.stride_mnl}
>;
"""
        return self._type_decl
= cutlass::epilogue::threadblock::VisitorCompute< + {op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]}, + {FloatRoundStyleTag[self.round_style]} +>; +""" + return self._type_decl + + +class Sm80AuxStoreImpl(AuxStoreImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, {DataTypeTag[self.element]}, {FloatRoundStyleTag[self.round_style]}, + {self.stride_mnl} +>; +""" + return self._type_decl + + +class Sm80StoreDImpl(Sm80AuxStoreImpl): + pass + + +class Sm80ColumnReductionImpl(ColumnReductionImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColReduction< + {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, + OutputTileThreadMap, {DataTypeTag[self.element]}, + {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]}, + {self.stride_mnl} +>; +""" + return self._type_decl + + +class Sm80RowReductionImpl(RowReductionImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowReduction< + {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, + OutputTileThreadMap, {DataTypeTag[self.element]}, + {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]}, + {self.stride_mnl} +>; +""" + return self._type_decl + + +class Sm80ScalarReductionImpl(ScalarReductionImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + 
return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarReduction< + {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, + OutputTileThreadMap, {DataTypeTag[self.element]}, + {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]}, + {self.stride_mnl} +>; +""" + return self._type_decl diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_emitter.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_emitter.py new file mode 100644 index 0000000000000000000000000000000000000000..3c058aa8f30a56d97ce3c3600f7c89189e7a15ad --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_emitter.py @@ -0,0 +1,98 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Emitter for Sm90 Epilogue Visitor +""" + +from cutlass_library import DataTypeTag, EpilogueScheduleTag +from cutlass_cppgen.backend import GemmOperationUniversal +from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks + + +class CollectiveEpilogue: + def __init__(self, tile_description, + schedule, + element_c, + element_d, + fusion_callbacks) -> None: + + self.cta_tile_mnk = tile_description.threadblock_shape + self.element_c = element_c + self.element_d = element_d + self.schedule = schedule + self.fusion_callbacks = fusion_callbacks + + @property + def CtaTileMNK(self) -> str: + """ + The threadblock shape + """ + return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>" + + @property + def EpilogueTileType(self) -> str: + """ + The epilogue tile type + """ + return "cutlass::epilogue::collective::EpilogueTileAuto" + + @property + def Schedule(self) -> str: + return EpilogueScheduleTag[self.schedule] + + def emit(self): + callback_decl, callback_name = self.fusion_callbacks.emit() + return callback_name, f""" +using 
EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor< + {self.CtaTileMNK}, {self.EpilogueTileType}, + {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]}, + {self.Schedule} +>; +{callback_decl} +""" + + +class Sm90Emitter: + def __init__(self, operation: GemmOperationUniversal, graph) -> None: + fusion_callbacks = FusionCallbacks(graph, cc=90, emit_CD=False) + + self.collective_epilogue = CollectiveEpilogue( + tile_description=operation.tile_description, + schedule=operation.tile_description.epilogue_schedule, + element_c=operation.C.element, + element_d=operation.C.element, + fusion_callbacks=fusion_callbacks + ) + + def emit(self): + return self.collective_epilogue.emit() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_nodes.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..43601a424e3ecb175837fb31389436c1470d9c0b --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_nodes.py @@ -0,0 +1,329 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from pycute import product + +from cutlass_library import DataTypeSize, DataTypeTag +from cutlass_cppgen.backend.evt.ir import ( + # Load Node + AccumulatorImpl, + AuxLoadImpl, + ColumnBroadcastImpl, + LoadNode, + LoadSrcImpl, + RowBroadcastImpl, + ScalarBroadcastImpl, + # Compute Node + ComputeImpl, + ComputeNode, + # Store Node + AuxStoreImpl, + ColumnReductionImpl, + RowReductionImpl, + ScalarReductionImpl, + StoreNode, + StoreDImpl, +) +from cutlass_cppgen.backend.library import ( + FloatRoundStyleTag, + FunctionalOp, + op_tag, +) + + +class Sm90AccumulatorImpl(AccumulatorImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::fusion::Sm90AccFetch;\n""" + return self._type_decl + + +class Sm90LoadSrcImpl(LoadSrcImpl): + + 
@property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using ElementC = {DataTypeTag[self.element]}; +using StrideC = {self.stride_mnl}; +using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch<{DataTypeTag[self.element]}>; +""" + return self._type_decl + + +class Sm90AuxLoadImpl(AuxLoadImpl): + + @property + def descriptor(self) -> str: + """ + Descriptor for Aux Load + """ + return f"{self.name_camel}Descriptor" + + def decl_descriptor(self) -> str: + """ + Declare the descriptor type + """ + return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::AuxLoadDescriptor;\n" + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = self.decl_descriptor() + self._type_decl += f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad< + {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]}, + {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R +>; +""" + return self._type_decl + + def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles): + """ + Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d + """ + return (DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8, 128) + + +class Sm90ScalarBroadcastImpl(ScalarBroadcastImpl): + def __init__(self, node: LoadNode) -> None: + super().__init__(node) + self.broadcast_count = 1 + self.reduction_fn = FunctionalOp.Multiplies + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarBroadcast< + {DataTypeTag[self.element]}, 
{self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)} +>; +""" + return self._type_decl + + +class Sm90RowBroadcastImpl(RowBroadcastImpl): + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]}, + {self.stride_mnl} +>; +""" + return self._type_decl + + +class Sm90ColumnBroadcastImpl(ColumnBroadcastImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]}, + {self.stride_mnl} +>; +""" + return self._type_decl + + +class Sm90ComputeImpl(ComputeImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90Compute< + {op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]}, + {FloatRoundStyleTag[self.round_style]} +>; +""" + return self._type_decl + + +class Sm90AuxStoreImpl(AuxStoreImpl): + + @property + def descriptor(self) -> str: + """ + Descriptor for Aux Load + """ + return f"{self.name_camel}Descriptor" + + def decl_descriptor(self) -> str: + """ + Declare the descriptor type + """ + return f""" +using {self.descriptor} = cutlass::epilogue::collective::detail::AuxStoreDescriptor< + EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]} +>; +""" + @property + def type_decl(self): + """ + Return the string defining the 
type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = self.decl_descriptor() + self._type_decl += f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore< + {self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]}, + {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, + typename {self.descriptor}::CopyOpR2S +>; +""" + return self._type_decl + + def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles): + """ + Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d + """ + return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128) + + +class Sm90StoreDImpl(StoreDImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + return f""" +using ElementD = {DataTypeTag[self.element]}; +using StrideD = {self.stride_mnl}; +""" + + +class Sm90ColumnReductionImpl(ColumnReductionImpl): + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColReduction< + {op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0, + typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, + {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]}, + {self.stride_mnl} +>; +""" + return self._type_decl + + +class Sm90RowReductionImpl(RowReductionImpl): + + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowReduction< + {op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0 /* Stages */, + 
typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, + {DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]}, + {self.stride_mnl} +>; +""" + return self._type_decl + + +class Sm90ScalarReductionImpl(ScalarReductionImpl): + + + @property + def type_decl(self): + """ + Return the string defining the type + """ + if self._type_decl is not None: + return self._type_decl + + self._type_decl = f""" +using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarReduction< + {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, + {DataTypeTag[self.element]}, {DataTypeTag[self.element_compute]}, + {FloatRoundStyleTag[self.round_style]}, {self.stride_mnl} +>; +""" + return self._type_decl diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/epilogue.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/epilogue.py new file mode 100644 index 0000000000000000000000000000000000000000..da446e76d9ebd9de04950a89b2451480492147a9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/epilogue.py @@ -0,0 +1,168 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Epilogue Visitor interface for compiling, and running visitor-based epilogue. 
+""" + +import ctypes + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +from cutlass_library import DataType +import numpy as np + +from cutlass_cppgen.backend.epilogue import EpilogueFunctorBase +import cutlass_cppgen.backend.evt.backend +from cutlass_cppgen.backend.frontend import TensorFrontend +from cutlass_cppgen.utils.datatypes import is_numpy_tensor +from cutlass_cppgen.backend.evt.passes.util import cc_map + + +class EpilogueFunctorVisitor(EpilogueFunctorBase): + """ + Apply an epilogue functor described by the epilogue EVT + + :param cc: compute capability + :param visitor_frontend: user-provide visitor frontend + + """ + def __init__(self, cc: int, visitor, element_compute=DataType.f32) -> None: + # Type of Emitter based on CC + self.emit_cls = getattr(cutlass_cppgen.backend.evt.backend, f"Sm{cc_map[cc]}Emitter") + + # Visitor Types + self.visitor = visitor + self.graph = visitor.dag_ir + + # Data types + self.element_epilogue = element_compute # element compute + self.element_output = self.graph.get_node_meta('D').underlying_impl.element + + # Epilogue Thread Type + epilogue_thread_type = self.visitor.epilogue_thread_type + if cc_map[cc] in [90, 100]: + self.arg_c_type = self.visitor.arg_c_type + self.arg_d_type = self.visitor.arg_d_type + output_names = self.visitor.return_names + reduction_names = self.visitor.reduction_names + + # Epilogue stages specialized for sm80 kernel + if cc == 80: + if hasattr(self.visitor, "epilogue_stages"): + self.epilogue_stages = self.visitor.epilogue_stages + assert self.epilogue_stages <= 2, "Only supports Stages <=2 in SM80 Epilogue" + + # Epilogue Argument Type + class _Arguments(ctypes.Structure): + """ + Concepts: + class _EpilogueArguments(ctypes.Structure): + _fields_ = [ + ("epilogue", _Arguments), <- this class + ("ptr_C", ctypes.c_void_p), + ("stride_C", StrideBatched_), + ("ptr_D", ctypes.c_void_p), + ("stride_D", StrideBatched_) + ] + """ + _fields_ = [ + 
("output_op", epilogue_thread_type) + ] + + def __init__(self, kwargs: dict) -> None: + # The user-input kwargs is a dict of (name: tensors) + # We first convert all of them to device pointers + ptr_kwargs = {} + for key in kwargs.keys(): + is_output = key in output_names and key not in reduction_names + ptr_kwargs[key] = self.get_tensor_ptr(key, kwargs, is_output) + # Initialize the thread arguments + self.output_op = epilogue_thread_type(ptr_kwargs) + + def get_tensor_ptr(self, tensor_name, kwargs, is_output=False): + """ + Helper function for extracting device pointer + """ + # Skip the special tensors + if cc in [90, 100]: + if tensor_name in ["C", "D"]: + return 0 + if tensor_name not in kwargs.keys(): + raise ValueError(f"Tensor {tensor_name} is not provided.") + tensor = kwargs[tensor_name] + + # For float scalar constant, directly return the value + if isinstance(tensor, float): + return tensor + + # The tensor frontend returns a device buffer for np.ndarray + # and device ptr for other frontends + buffer_or_ptr = TensorFrontend.argument(tensor, is_output) + if is_numpy_tensor(tensor): + # Remember the host tensor for later synchronization + setattr(self, f"{tensor_name}_buffer", buffer_or_ptr) + setattr(self, f"{tensor_name}_host", tensor) + return int(buffer_or_ptr.ptr) + else: + return int(buffer_or_ptr) + + def sync(self): + """ + Synchronize the results from device to host + """ + for name in output_names: + if hasattr(self, f"{name}_host"): + host_tensor = getattr(self, f"{name}_host") + tensor_ptr = getattr(self, f"{name}_buffer").ptr + (err,) = cuda.cuMemcpyDtoH( + host_tensor, + tensor_ptr, + host_tensor.size * host_tensor.itemsize, + ) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + + self.epilogue_type = _Arguments + + def emit(self, operation): + """ + Emit the C++ code + """ + emitter = self.emit_cls(operation, self.graph) + return emitter.emit() + + def get_smem_size(self, tile_description): + """ + 
Get the shared memory size in bytes + """ + return self.visitor.get_smem_size(tile_description) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2323278ed232adea205e41b901c62a268e56976 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/__init__.py @@ -0,0 +1,33 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_cppgen.backend.evt.frontend.python_ast import PythonASTFrontend diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/frontend_base.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/frontend_base.py new file mode 100644 index 0000000000000000000000000000000000000000..213aafdbe3f922f22186e37ac9f2eefea74e71ce --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/frontend_base.py @@ -0,0 +1,272 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
class EVTFrontendBase:
    """
    Base class for Python EVT frontends.

    A frontend populates ``self.dag_ir`` through the ``add_*`` helpers while
    parsing and then lowers it by running ``self.pass_manager`` in
    :meth:`trace`. Subclasses must override :meth:`parse`.
    """
    # Layout manipulation functions, looked up by name
    layout_fns = {
        "permute": permute,
        "reshape": reshape
    }

    def __init__(self, cc, element_compute=DataType.f32, additional_passes=None, **kwargs) -> None:
        """
        :param cc: compute capability the epilogue is traced for
        :param element_compute: compute data type, normalized with ``library_type``
        :param additional_passes: extra passes appended after the default pipeline
        :param kwargs: accepted and ignored by this base class
        """
        self.cc = cc
        self.element_compute = library_type(element_compute)
        self.dag_ir = DAGIR(self.cc, self.element_compute)
        # Counters used to generate unique node names
        self.compute_cnt = 0
        self.layout_cnt = 0
        self.imm_cnt = 0

        # ``None`` stands for "no extra passes"; this avoids a shared mutable default
        extra_passes = [] if additional_passes is None else list(additional_passes)
        self.pass_manager = EVTPassManager(
            self.dag_ir,
            [
                PassPreprocessRed,
                PassGetArgumentType,
                PassShapeTypePropagation,
                PassLayoutManipulateElimination,
                PassGetImpl,
                PassDAG2Tree,
                PassFixElementD
            ] + extra_passes)

        # Only cc 80 gets a default stage count here; other targets start as None
        self._epilogue_stages = 1 if self.cc == 80 else None

    @property
    def epilogue_stages(self):
        """Number of epilogue stages (``None`` when it has not been set)."""
        return self._epilogue_stages

    @epilogue_stages.setter
    def epilogue_stages(self, stages):
        self._epilogue_stages = stages

    def parse(self, *args, **kwargs):
        """
        Build the DAG IR from the frontend-specific input. Must be overridden.
        """
        raise NotImplementedError("The 'parse' function must be overloaded in frontend class")

    def trace(self, *args, **kwargs):
        """
        Parse the input, verify the DAG IR and run the pass pipeline.

        :raises RuntimeError: if cc >= 90 and node "D" has users
        """
        # Parse the input
        self.parse(*args, **kwargs)

        # Verify the DAG IR to ensure that "D" is the output node with out_degree = 0
        if self.cc >= 90 and self.dag_ir.out_degree("D") != 0:
            raise RuntimeError(
                f"On SM90 or higher, D is expected to be a output node with 0 users to "
                f"enable smem reuse between C and D, but got {self.dag_ir.out_degree('D')}")

        # Run the passes
        self.pass_manager()
        # Set the epilogue type
        self.epilogue_thread_type = self.dag_ir.epilogue_thread_type
        if cc_map[self.cc] in [90, 100]:
            self.arg_c_type = self.dag_ir.arg_c_type
            self.arg_d_type = self.dag_ir.arg_d_type
        self.reduction_names = self.dag_ir.reduction_names

    #
    # Helper functions for DAG IR manipulation
    #

    def add_node(self, node):
        """Add a node to the DAG IR."""
        self.dag_ir.add_node(node)

    def add_edge(self, src, tgt, weight=0):
        """Add the edge ``src -> tgt`` with the given weight to the DAG IR."""
        self.dag_ir.add_edge(src, tgt, weight=weight)

    def set_tensor(self, node_name, example):
        """
        Add an example tensor to node {node_name} in the DAG IR
        """
        meta = self.dag_ir.get_node_meta(node_name)
        meta.tensor = {"tensor": example}

    def set_store_tensor(self, node_name, example):
        """
        Add an example store tensor to node {node_name} in the DAG IR
        """
        meta = self.dag_ir.get_node_meta(node_name)
        meta.store_tensor = {"tensor": example}

    def mark_output(self, node_name):
        """
        Mark a store node as output

        :raises ValueError: if the node is not a ``StoreNode``
        """
        meta = self.dag_ir.get_node_meta(node_name)
        if not isinstance(meta, StoreNode):
            raise ValueError(
                f"Only StoreNodes can be marked as output. "
                f"Got {type(meta).__name__}: {node_name}")
        meta.is_output = True

    # Add node with specific type

    def add_load_node(self, name, example):
        """
        Add a Load node to DAG IR
        :param name: name of the loaded variable
        :type name: str
        :param example: example input
        :type example: np.ndarray|torch.Tensor|cupy.ndarray|float
        :raises ValueError: if name or example is None, or "accum" is not rank 2 or 3
        """
        if name is None:
            raise ValueError("Name is not provided.")
        if example is None:
            raise ValueError(f"Example input for {name} is not provided.")
        load_node = LoadNode(name)
        load_node.tensor = {"tensor": example}
        # Special logics for accumulator
        if name == "accum":
            rank = load_node.tensor.rank
            if rank == 2:
                # Prepend a unit dimension so the accumulator is always rank 3
                new_shape = tuple([1, ] + list(load_node.tensor.shape))
                load_node.tensor.broadcast(new_shape)
            elif rank < 2 or rank > 3:
                raise ValueError(f"Expect example inputs for 'accum' be a rank-2 or rank-3 tensor. Got {load_node.tensor.shape}.")
        self.add_node(load_node)

    def add_imm(self, value: Union[float,int]):
        """
        Add an immediate scalar value to DAG IR
        :param value: the value of the immediate scalar
        :type value: float
        :return: the name of the generated load node
        :raises ValueError: if value cannot be converted to float
        """
        try:
            value = float(value)
        except (TypeError, ValueError, OverflowError) as err:
            # Catch only conversion failures; a bare except would also trap
            # KeyboardInterrupt / SystemExit
            raise ValueError(f"{type(value).__name__} cannot be converted to float.") from err

        # '.' is replaced so that the name stays a valid identifier-like token
        name = f"imm_{value}_k{self.imm_cnt}".replace('.', '_')
        self.imm_cnt += 1
        load_node = LoadNode(name)
        load_node.tensor = {"tensor": value, "is_constant": True}
        self.add_node(load_node)
        return name

    def add_compute_node(self, op, name=None):
        """
        Add a compute node.
        :param op: the computation op
        :param name: the node name (optional)
        :type name: str
        :return: the name of the compute node
        """
        if name is None:
            name = f"compute_{self.compute_cnt}"
            self.compute_cnt += 1
        compute_node = ComputeNode(
            name=name, fn=op,
            element_output=self.element_compute,
            element_compute=self.element_compute)
        self.add_node(compute_node)
        return compute_node.name

    def add_layout_node(self, op, kwargs, name=None):
        """
        Add a layout node.
        :param op: the layout op
        :type op: evt_ops
        :param kwargs: keyword arguments forwarded to the layout node
        :param name: the node name (optional)
        :type name: str
        :return: the name of the layout node
        """
        if name is None:
            name = f"layout_{self.layout_cnt}"
            self.layout_cnt += 1
        layout_node = LayoutNode(name=name, fn=op, kwargs=kwargs)
        self.add_node(layout_node)
        return layout_node.name

    def add_store_node(self, name):
        """Add a Store node named ``name`` to the DAG IR."""
        store_node = StoreNode(name)
        self.add_node(store_node)

    #
    # Visualization The DAG IR
    #

    def visualize(self, name="dag_ir"):
        """
        Visualize the dag ir with svg file
        :param name: the name of the graph
        :raises RuntimeError: if drawing or writing any graph fails
        """
        drawer = EVTGraphDrawer(self.dag_ir, name)
        try:
            # A distinct loop variable keeps the ``name`` argument intact
            for graph_name, graph in drawer.get_dot_graph():
                graph.write_svg(f"./{graph_name}.svg")
        except Exception as err:
            # Keep the historical message but chain the real cause for debugging
            raise RuntimeError(
                "'dot' is not found in path. GraphDrawer is disabled. "
                "Please install it with 'sudo apt-get install graphviz'."
            ) from err

    #
    # Get shared memory size
    #

    def get_smem_size(self, tile_description):
        """
        Get the shared memory size of the epilogue
        """
        smem_size = GetSmemSize(self.dag_ir)(tile_description)
        return smem_size
class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
    """
    Frontend that builds the DAG IR by walking the Python AST of ``self.__call__``.

    Function arguments become load nodes, assignments become store nodes,
    binary operators and calls become compute (or layout) nodes, and the
    returned names are marked as outputs.
    """
    def __init__(self, cc, element_compute=DataType.f32, **kwargs):
        super().__init__(cc, element_compute, **kwargs)
        # Flags
        # If this state is True, visit_Constant returns values without creating imm node
        self.no_imm = False
        # True only while the value of a ``return`` statement is being visited
        self.visiting_return = False

    def parse(self, example_inputs):
        """
        Parse the source of ``self.__call__`` into the DAG IR.

        :param example_inputs: container indexed by variable name that provides
            an example for every argument and every returned value
        """
        self.example_inputs = example_inputs
        self.source = textwrap.dedent(inspect.getsource(self.__call__))
        self.ast = ast.parse(self.source)
        self.visit(self.ast)

    #
    # Helper functions
    #
    @staticmethod
    def ast_op_to_bindings(op):
        """
        Map an ``ast`` operator class or a function name to its binding.

        :raises KeyError: if ``op`` is not supported
        """
        mapping = {
            ast.Add: FunctionalOp.Plus,
            ast.Sub: FunctionalOp.Minus,
            ast.Mult: FunctionalOp.Multiplies,
            ast.Div: FunctionalOp.Divides,
            "maximum": FunctionalOp.Maximum,
            "minimum": FunctionalOp.Minimum,
            "identity": identity.binding_type,
            "relu": relu.binding_type,
            "tanh": tanh.binding_type,
            "sigmoid": sigmoid.binding_type,
            "silu": silu.binding_type,
            "hardswish": hardswish.binding_type,
            "gelu": gelu.binding_type,
            "multiply_add": FunctionalOp.MultiplyAdd,
            "sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
            "max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum),
            "exp": FunctionalOp.Exp
        }
        return mapping[op]

    #
    # Visiting different node types
    #

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Register every argument as a load node, then visit the body in order."""
        # Visit args and register load nodes
        for arg in node.args.args:
            self.visit(arg)
        for expr in node.body:
            self.visit(expr)

    def visit_arg(self, node: ast.arg):
        """
        Add a load node for a function argument.

        :raises RuntimeError: if no example input exists for the argument
        """
        # Name of the argument
        name = node.arg
        try:
            example_tensor = self.example_inputs[name]
        except (KeyError, TypeError) as err:
            raise RuntimeError(f"Example input for {name} is not provided.") from err

        self.add_load_node(name, example_tensor)

    def visit_Name(self, node: ast.Name):
        """Return the identifier; names double as node names in the DAG IR."""
        return node.id

    def visit_Constant(self, node: ast.Constant):
        """Return the raw value when ``no_imm`` is set, else the name of a new imm node."""
        if self.no_imm:
            return node.value
        return self.add_imm(node.value)

    def visit_Tuple(self, node: ast.Tuple):
        """Visit every element and return the results as a tuple."""
        return tuple(self.visit(elt) for elt in node.elts)

    def visit_keyword(self, node: ast.keyword):
        """Return the keyword as a single-entry ``{name: value}`` dict."""
        return {node.arg: self.visit(node.value)}

    def visit_BinOp(self, node: ast.BinOp):
        """
        Add a compute node for a binary operation and return its name.

        :raises SyntaxError: if the operation appears directly in a return statement
        """
        if self.visiting_return:
            raise SyntaxError("Return value cannot be an expression")
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        op = self.ast_op_to_bindings(type(node.op))
        name = self.add_compute_node(op)

        # Add edges
        # The edge weights are used to sort the input args
        self.add_edge(lhs, name, weight=0)
        self.add_edge(rhs, name, weight=1)
        return name

    def visit_Assign(self, node: ast.Assign):
        """Add a store node for the first assignment target, fed by the assigned value."""
        target = self.visit(node.targets[0])
        value = self.visit(node.value)
        # Create the assign node
        self.add_store_node(target)

        # Add edges
        self.add_edge(value, target)
        return target

    def visit_Call(self, node: ast.Call):
        """
        Add a layout node (for names in ``layout_fns``) or a compute node for a call.

        :raises SyntaxError: if the call appears directly in a return statement
        """
        if self.visiting_return:
            raise SyntaxError("Return value cannot be an expression")
        func = self.visit(node.func)
        args = [self.visit(arg) for arg in node.args]

        if func in self.layout_fns:
            # Parse kwargs
            # By default, visiting imm automatically creates a load node
            # However, in function call, keyword args are used to set
            # specific function attributes such as indices for permute
            # So no_imm is set to True temporarily
            kwargs = {}
            self.no_imm = True
            try:
                for kw in node.keywords:
                    kwargs.update(self.visit(kw))
            finally:
                # Reset the flag even when visiting a keyword raises
                self.no_imm = False
            op = self.layout_fns[func]
            name = self.add_layout_node(op, kwargs)
        else:
            op = self.ast_op_to_bindings(func)
            name = self.add_compute_node(op)

        # Add edges
        for idx, arg in enumerate(args):
            self.add_edge(arg, name, weight=idx)
        return name

    def visit_Return(self, node: ast.Return):
        """
        Record the returned names and mark each of them as an output.

        :raises RuntimeError: if no example input exists for a returned name
        """
        self.visiting_return = True
        try:
            results = self.visit(node.value)
        finally:
            # Reset the flag even when the returned expression is rejected
            self.visiting_return = False
        self.return_names = results
        if not isinstance(results, tuple):
            results = (results,)
        for rst in results:
            try:
                example_tensor = self.example_inputs[rst]
            except (KeyError, TypeError) as err:
                raise RuntimeError(f"Example input for {rst} is not provided.") from err
            self.set_store_tensor(rst, example_tensor)
            self.mark_output(rst)
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode, ComputeImpl +from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR +from cutlass_cppgen.backend.evt.ir.layout_nodes import LayoutNode +from cutlass_cppgen.backend.evt.ir.load_nodes import ( + LoadNode, + AccumulatorImpl, + LoadSrcImpl, + AuxLoadImpl, + RowBroadcastImpl, + ColumnBroadcastImpl, + ScalarBroadcastImpl +) +from cutlass_cppgen.backend.evt.ir.node import TopoVisitorNode, NoOpImpl +from cutlass_cppgen.backend.evt.ir.store_nodes import ( + StoreNode, + StoreDImpl, + AuxStoreImpl, + ColumnReductionImpl, + RowReductionImpl, + ScalarReductionImpl +) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/compute_nodes.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/compute_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..02b05358648694dcf2a5afd7117e6fca6a2d136c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/compute_nodes.py @@ -0,0 +1,91 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
class ComputeImplBase(ImplBase):
    """
    Common base class of the compute node implementations
    """
    def __init__(self, node) -> None:
        super().__init__(node)


class ComputeImpl(ComputeImplBase):
    """
    Default implementation of a Compute Node
    """
    def __init__(self, node) -> None:
        super().__init__(node)

        # Mirror the attributes of the node on the implementation
        for attr in ("fn", "element_output", "element_compute", "round_style"):
            setattr(self, attr, getattr(node, attr))

    @staticmethod
    def match(node, problem_size: tuple):
        # Accepts every node and problem size
        return True


class ComputeNode(NodeBase):
    """
    Compute Node in DAG IR
    """
    possible_impls = [
        ComputeImpl
    ]
    def __init__(
        self, name: str, fn, element_output,
        element_compute,
        round_style=FloatRoundStyle.ToNearest) -> None:
        # NOTE: `element_output` is accepted but not stored here; the attribute
        # is assigned in `type_propagation` unless something set it before
        super().__init__(name)
        self.op = "compute"
        self.fn = fn
        self.element_compute = element_compute
        self.round_style = round_style

    def type_propagation(self, *args, **kwargs):
        """
        The compute node produces values of type `element_compute`; `element_output`
        falls back to the same type when it has not been assigned yet.
        """
        self.element = self.element_compute
        # In general, the compute nodes have element_output = element_compute
        # In certain cases like producer of D it is overwritten by other passes
        if hasattr(self, "element_output"):
            return
        self.element_output = self.element
class DAGIR:
    """
    ``DAGIR`` is the main data structure used in the EVT Intermediate Representation.
    It consists of a series of ``Node`` s, each representing epilogue visitor nodes.

    In the DAGIR, ``node`` is a string holding the node's name. ``node_meta`` is the
    underlying class of the node
    """
    def __init__(self, cc, element_compute=DataType.f32) -> None:
        # The EVT DAGIR is managed through the networkx DiGraph class
        self._graph = nx.DiGraph()

        self.element_compute = element_compute

        self.reduction_names = []

        self.cc = cc

        # Counter used to name the auto-generated identity nodes
        self.identity_counter = 0

    #
    # IR manipulator
    #

    def add_node(self, meta: NodeBase):
        """
        Register ``meta`` in the dag ir under ``meta.name``

        :raises SyntaxError: if a node of the same name already exists
        """
        if self.has_node(meta.name):
            raise SyntaxError(f"Variable '{meta.name}' cannot be defined twice.")
        self._graph.add_node(meta.name, meta=meta)

    def add_edge(self, src: str, dst: str, weight: int=0):
        """
        Insert the weighted edge src -> dst into the dag ir

        :raises SyntaxError: if either endpoint is not a node of the graph
        """
        for endpoint in (src, dst):
            if not self.has_node(endpoint):
                raise SyntaxError(f"Variable '{endpoint}' is undefined.")

        if not self._graph.has_edge(src, dst):
            self._graph.add_edge(src, dst, weight=weight)
            return

        # The DiGraph doesn't support multiple edges between two nodes
        # We insert an identity node in such case as a workaround
        identity_name = f"autogen_identity_{self.identity_counter}"
        self.identity_counter += 1
        identity_node = ComputeNode(
            name=identity_name, fn=ActivationOp.Identity,
            element_output=self.element_compute,
            element_compute=self.element_compute)
        self.add_node(identity_node)
        self.add_edge(src, identity_name, 0)
        self.add_edge(identity_name, dst, weight)

    def remove_node(self, node: str):
        """
        Delete ``node`` from the dag ir
        """
        self._graph.remove_node(node)

    def remove_edge(self, src: str, dst: str):
        """
        Delete the edge src -> dst
        """
        self._graph.remove_edge(src, dst)

    #
    # Helper functions for getting attrs
    #

    def has_node(self, node: str) -> bool:
        """
        Tell whether ``node`` is in the graph
        """
        return self._graph.has_node(node)

    def in_degree(self, node: str):
        """
        Number of edges entering ``node``
        """
        return self._graph.in_degree(node)

    def in_edges(self, node: str):
        """
        List of the edges entering ``node``
        """
        return list(self._graph.in_edges(node))

    def out_degree(self, node: str):
        """
        Number of edges leaving ``node``
        """
        return self._graph.out_degree(node)

    def out_edges(self, node: str):
        """
        List of the edges leaving ``node``
        """
        return list(self._graph.out_edges(node))

    def get_node_meta(self, node: str):
        """
        Return the meta object stored for ``node``
        """
        return self._graph.nodes[node]["meta"]

    def get_edge_weight(self, src, dst):
        """
        Return the weight stored on the edge src->dst
        """
        edge_data = self._graph.get_edge_data(src, dst)
        return edge_data["weight"]

    #
    # High-level helper functions
    #

    def all_reachable_nodes(self, node: str):
        """
        Get the nodes reachable from ``node`` in depth-first preorder

        NOTE(review): the previous docstring said ``node`` itself is excluded, but
        ``nx.dfs_preorder_nodes`` appears to yield the source first - confirm
        """
        return list(nx.dfs_preorder_nodes(self._graph, source=node))

    def get_users(self, node: str):
        """
        Get every node that consumes ``node``
        """
        return [user for _, user in self.out_edges(node)]

    def get_all_inputs(self, node: str):
        """
        Get all the input nodes sorted by edge weight
        """
        ranked = sorted(
            (self.get_edge_weight(*edge), edge) for edge in self.in_edges(node))
        return [edge[0] for _, edge in ranked]

    def get_all_inputs_meta(self, node: str):
        """
        Get all the input node metas sorted by edge weight
        """
        return [self.get_node_meta(name) for name in self.get_all_inputs(node)]

    def replace_all_uses_with(self, node1, node2):
        """
        Redirect every use of node1 to node2, then drop node1
        """
        for _, user in self.out_edges(node1):
            weight = self.get_edge_weight(node1, user)
            self.add_edge(node2, user, weight)
            self.remove_edge(node1, user)
        self.remove_node(node1)

    #
    # Node accessor
    #
    def nodes_topological_order(self):
        """
        Get the nodes in the unique lexicographical topological order
        It generates a unique ordering of nodes by first sorting topologically
        and then additionally by sorting lexicographically.

        Although topological_sort alone also works, this generates a unique key
        for each epilogue visitor pattern and ensures the compilation cache can be reused.
        :return: list[str]
        """
        return list(nx.lexicographical_topological_sort(self._graph))

    def node_metas_topological_order(self):
        """
        Get the node metas in topological order
        :return: list[NodeBase]
        """
        return [self.get_node_meta(name) for name in self.nodes_topological_order()]

    @property
    def nodes(self):
        """
        Names of all nodes
        :return: list[str]
        """
        return list(self._graph.nodes)

    @property
    def nodes_meta(self):
        """
        Metas of all nodes
        :return: list[NodeBase]
        """
        return [attrs['meta'] for _, attrs in self._graph.nodes.data()]

    @property
    def edges(self):
        """
        All edges of the graph
        :return: list[(str, str)]
        """
        return list(self._graph.edges)

    #
    # Path
    #
    def has_path(self, src: str, target: str) -> bool:
        """
        Return True if a path exists from src to target
        """
        return nx.has_path(self._graph, src, target)
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Layout algebras +""" + +from pycute import Layout, composition, make_layout, flatten, product + + +def _infer_split(old_shape, new_shape): + old_shape = _tuple_to_list(old_shape) + new_shape = _tuple_to_list(new_shape) + if len(old_shape) == 0 and len(new_shape) == 0: + return [] + if len(old_shape) == 0: + if product(tuple(new_shape)) != 1: + raise ValueError("Invalid reshape size") + else: + return new_shape + if len(new_shape) == 0: + if product(tuple(old_shape)) != 1: + raise ValueError("Invalid reshape size") + else: + return old_shape + # This is done recursively by only process the last dimension at each time + old_dim = old_shape[-1] + new_dim = new_shape[-1] + # Exact match + if old_dim == new_dim: + return _infer_split(old_shape[:-1], new_shape[:-1]) + [new_dim,] + # Needs split + if old_dim > new_dim and old_dim % new_dim == 0: + residual = old_dim // new_dim + return _infer_split(old_shape[:-1] + [residual,], new_shape[:-1]) + [new_dim,] + # Needs merge + if old_dim < new_dim and new_dim % old_dim == 0: + residual = new_dim // old_dim + return _infer_split(old_shape[:-1], new_shape[:-1] + [residual,]) + [old_dim,] + + raise NotImplementedError(f"Unsupported split: {old_shape} -> {new_shape}") + +def _infer_merge(flatten_shape, shape): + flatten_shape = _tuple_to_list(flatten_shape) + shape = _tuple_to_list(shape) + idx_flat = 0 + merged_shape = [] + for dim in shape: + # Exact match + if dim == flatten_shape[idx_flat]: + merged_shape.append(dim) + idx_flat += 1 + # Need group + elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0: + residual = dim + group = [] + while(residual > 1): + group.append(flatten_shape[idx_flat]) + residual = residual // flatten_shape[idx_flat] + idx_flat += 1 + merged_shape.append(group) + else: + raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}") + + return merged_shape + 
+def _list_to_tuple(nested_list): + if isinstance(nested_list, list) or isinstance(nested_list, tuple): + return tuple(_list_to_tuple(item) for item in nested_list) + return nested_list + +def _tuple_to_list(nested_tuple): + if isinstance(nested_tuple, list) or isinstance(nested_tuple, tuple): + return list(_tuple_to_list(item) for item in nested_tuple) + return nested_tuple + +def _reverse_tuple(nested_tuple: tuple): + if isinstance(nested_tuple, tuple): + return tuple([_reverse_tuple(item) for item in nested_tuple][::-1]) + return nested_tuple + +def _get_first_lhs_nonzero_stride(stride_list, idx): + for i in reversed(range(idx)): + if stride_list[i] != 0: + return i + else: + return None + +def _get_first_rhs_nonzero_stride(stride_list, idx): + for i in range(idx+1, len(stride_list)): + if stride_list[i] != 0: + return i + else: + return None + +def reshape(layout, new_shape): + """ + General reshape of input layout. + It takes two steps: + 1. split the dimensions of the old layout + 2. 
merge the splitted dimensions according to the new shape + """ + # + # Step 1: Split the dimensions of the old layout + # + # 1.1 Flat old and new shape + old_flatten_shape = list(flatten(layout.shape)) + new_flatten_shape = list(flatten(new_shape)) + + # 1.2 Infer the flatten splitted shape + splitted_flatten_shape = _infer_split(old_flatten_shape, new_flatten_shape) + + # 1.3 Unflat the splitted shape based on the old shape + splited_shape = _infer_merge(splitted_flatten_shape, old_flatten_shape) + + # 1.4 Infer the type of each split + # If the split type is in row-major (R), the dimension list is reversed because + # the cute::composition only support column-major split + split_type = [] # the type of each split (ColumnMajor or RowMajor) + permuted_splitted_shape = [] + old_flatten_stride = list(flatten(layout.stride)) + for idx, dim in enumerate(splited_shape): + if not isinstance(dim, list): + permuted_splitted_shape.append(dim) + split_type.append("C") + else: + lhs_stride = _get_first_lhs_nonzero_stride(old_flatten_stride, idx) + rhs_stride = _get_first_rhs_nonzero_stride(old_flatten_stride, idx) + # Special case for single tuple + # Use column-major by default + if lhs_stride is None and rhs_stride is None: + permuted_splitted_shape.append(dim) + split_type.append("C") + else: + if lhs_stride is not None and rhs_stride is not None: + # We consider shape[idx]:stride[idx] + # Case 1: stride[idx - 1] <= stride[idx] <= stride[idx + 1]: column major + if lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride: + permuted_splitted_shape.append(dim) + split_type.append("C") + # Case 2: stride[idx - 1] > stride[idx] > stride[idx + 1]: row major + elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride: + permuted_splitted_shape.append([d for d in reversed(dim)]) + split_type.append("R") + # Case 3: stride[idx - 1] <= stride[idx] > stride[idx + 1]: concave + elif lhs_stride <= old_flatten_stride[idx] and 
old_flatten_stride[idx] > rhs_stride: + if lhs_stride >= rhs_stride: + permuted_splitted_shape.append(dim) + split_type.append("C") + else: + permuted_splitted_shape.append([d for d in reversed(dim)]) + split_type.append("R") + # Case 4: stride[idx - 1] > stride[idx] <= stride[idx + 1]: concave + elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride: + if lhs_stride >= rhs_stride: + permuted_splitted_shape.append(dim) + split_type.append("C") + else: + permuted_splitted_shape.append([d for d in reversed(dim)]) + split_type.append("R") + else: + raise NotImplementedError() + elif lhs_stride is None: + # Case 1: dim's stride < dim+1's stride, expand in column major + if old_flatten_stride[idx] > rhs_stride: + permuted_splitted_shape.append([d for d in reversed(dim)]) + split_type.append("R") + else: + permuted_splitted_shape.append(dim) + split_type.append("C") + else: + # Case 1: dim's stride > dim-1's stride + if old_flatten_stride[idx] < lhs_stride: + permuted_splitted_shape.append([d for d in reversed(dim)]) + split_type.append("R") + else: + permuted_splitted_shape.append(dim) + split_type.append("C") + + # 1.4 Generate the splitted layout + permuted_splitted_layout = composition(layout, Layout(_list_to_tuple(permuted_splitted_shape))) + + # 1.5 Reverse the permutation in 1.4 before merge + splitted_shape = [] + splitted_stride = [] + for shape_dim, stride_dim, type in zip( + permuted_splitted_layout.shape, + permuted_splitted_layout.stride, + split_type): + if type == "C": + splitted_shape.append(shape_dim) + splitted_stride.append(stride_dim) + else: + splitted_shape.append(tuple([d for d in reversed(shape_dim)])) + splitted_stride.append(tuple([d for d in reversed(stride_dim)])) + splitted_layout = Layout(tuple(splitted_shape), tuple(splitted_stride)) + + + # + # Step 2: Merge the splitted dimensions according to the new shape + # + # 2.1 Merge layout + merged_layout = composition(splitted_layout, Layout(new_shape)) + + # 2.2 
Cleaning up + output_layout = composition(merged_layout, Layout(new_shape)) + return output_layout + + +def permutation(layout, permutation): + """ + Permute the layout + """ + new_shape = tuple([layout.shape[idx] for idx in permutation]) + new_stride = tuple([layout.stride[idx] for idx in permutation]) + return Layout(new_shape, new_stride) + + +def _broadcast(layout, new_shape): + if len(layout) == 1 and isinstance(new_shape, int): + old_dim = layout.shape + old_stride = layout.stride + new_dim = new_shape + if old_dim == new_dim: + return Layout(old_dim, old_stride) + elif old_dim == 1: + return Layout(new_dim, 0) + else: + raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {new_dim}") + + # Align the dimensions + old_shape = layout.shape + if isinstance(old_shape, int): + old_shape = (old_shape,) + sub_layouts = [layout,] + else: + sub_layouts = [sub_layout for sub_layout in layout] + rhs_broadcast_layouts = [Layout(1, 0)] * (len(new_shape) - len(old_shape)) + # Get the broadcasted layout + broadcast_layouts = [] + try: + layout = make_layout(*sub_layouts, *rhs_broadcast_layouts) + broadcast_layouts = [] + for idx, sub_layout in enumerate(layout): + broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx])) + except NotImplementedError: + layout = make_layout(*rhs_broadcast_layouts, *sub_layouts) + for idx, sub_layout in enumerate(layout): + broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx])) + return make_layout(*broadcast_layouts) + + +def broadcast(layout, new_shape): + """ + Broadcast the new layout based on the input shape + The broadcasted shape equals to the new shape + The stride of broadcasted dimensions are 0 + """ + return _broadcast(layout, new_shape) + + +def debroadcast(layout, dims): + """ + Squeeze the 0-stride + """ + for dim in dims: + if layout.stride[dim] != 0: + raise ValueError(f"Dim{dim} cannot be debroadcasted as it has stride {layout.stride[dim]}") + new_shape = tuple([s for idx, s in enumerate(layout.shape) 
if idx not in dims]) + new_stride = tuple([s for idx, s in enumerate(layout.stride) if idx not in dims]) + return Layout(new_shape, new_stride) + + +def canonicalization_(shapes, strides): + if isinstance(shapes, tuple): + c_shapes = [] + c_strides = [] + for shape, stride in zip(shapes, strides): + c_shape, c_stride = canonicalization_(shape, stride) + c_shapes.append(c_shape) + c_strides.append(c_stride) + return tuple(c_shapes), tuple(c_strides) + else: + if shapes == 1: + return 1, 0 + else: + return shapes, strides + +def canonicalization(layout): + """ + Canonicalize the input layout + 1. set the stride of shape "1" to 0 + """ + new_shape, new_stride = canonicalization_(layout.shape, layout.stride) + return Layout(new_shape, new_stride) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_nodes.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..1095e2ab1d956399b5e27ddaf140e53d9918ec26 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_nodes.py @@ -0,0 +1,336 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Layout manipulation nodes and implementations + +The layout Nodes change the layout of intermediate nodes in epilogue visitor graph +""" + +from copy import deepcopy + +from cutlass_library import LayoutType +from pycute import product, flatten + +import cutlass_cppgen +from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list +from cutlass_cppgen.backend.evt.ir.node import NodeBase +from cutlass_cppgen.backend.evt.ir.tensor import Tensor + + +class PermutationImpl: + """ + Detailed implementation and helper functions for permutation + """ + def __init__(self, node) -> None: + assert "indices" in node.kwargs.keys() + self.indices = list(node.kwargs["indices"]) + self.inverse_indices = self.get_inverse_indices(self.indices) + + def get_inverse_impl(self): + inverse_impl = deepcopy(self) + inverse_impl.indices = self.inverse_indices + inverse_impl.inverse_indices = self.indices + return inverse_impl + + def update(self, shape): + num_dim = len(shape) + indices = self.indices + num_old_dim = len(indices) + # Add offset + for i, idx in enumerate(indices): + indices[i] = idx + num_dim - num_old_dim + # Add broadcast dims + for i in range(num_dim - num_old_dim): + indices = [i,] + indices + + self.indices = indices + self.inverse_indices = self.get_inverse_indices(self.indices) + + def get_inverse_indices(self, indices): + """ + Get the indices for inverse permutation + """ + num_dim = len(indices) + inverse_indices = [0] * num_dim + for i in range(num_dim): + inverse_indices[indices[i]] = i + return inverse_indices + + def shape_propagation(self, input_node_meta): + input_shape = input_node_meta.tensor.shape + output_shape = tuple([input_shape[idx] for idx in self.indices]) + return output_shape + + def broadcast(self, shape, node_meta: NodeBase): + """ + Broadcast the inputs based on current shape + """ + self.update(shape) + 
inverse_shape = tuple([shape[idx] for idx in self.inverse_indices]) + node_meta.tensor.broadcast(inverse_shape) + + def apply_to_user(self, usr_meta: NodeBase): + """ + Propagate the permutation to the users of the current nodes + """ + usr_meta.tensor.permute(self.inverse_indices) + if hasattr(usr_meta, "store_tensor"): + if usr_meta.store_tensor is not None: + usr_meta.store_tensor.permute(self.inverse_indices) + + def apply_to_input(self, input_meta: NodeBase): + """ + Propagate the permutation to inputs of the current nodes + """ + input_meta.tensor.permute(self.indices) + if hasattr(input_meta, "store_tensor"): + if input_meta.store_tensor is not None: + input_meta.store_tensor.permute(self.indices) + + +class ReshapeImpl: + """ + Detailed implementation and helper functions for reshape + """ + def __init__(self, node) -> None: + self.node = node + assert "new_shape" in node.kwargs.keys() + self.output_shape = _list_to_tuple(node.kwargs["new_shape"]) + + def get_inverse_impl(self): + inverse_impl = deepcopy(self) + inverse_impl.output_shape = self.input_shape + inverse_impl.input_shape = self.output_shape + return inverse_impl + + def shape_propagation(self, input_node_meta): + self.input_shape = input_node_meta.tensor.shape + return _list_to_tuple(self.output_shape) + + def broadcast(self, shape, node_meta: NodeBase): + """ + Broadcast the inputs based on current shape. 
+ """ + # Step 1: infer split + flatten_split_shape = self.infer_split(flatten(self.input_shape), flatten(self.output_shape)) + split_input_shape = self.infer_merge(flatten_split_shape, self.input_shape) + split_output_shape = self.infer_merge(flatten_split_shape, self.output_shape) + + # broadcast shape -> split_output_shape -> flatten_split_shape + if len(shape) - len(split_output_shape) > 0: + for _ in range(len(shape) - len(split_output_shape)): + split_output_shape = [1,] + split_output_shape + flatten_split_shape = [1,] + flatten_split_shape + split_input_shape = [1,] + split_input_shape + broadcast_factor = [] + for dim, old_dim in zip(shape, split_output_shape): + if not isinstance(dim, list): + dim = [dim,] + if not isinstance(old_dim, list): + old_dim = [old_dim,] + if product(tuple(dim)) == product(tuple(old_dim)): + broadcast_factor += [1] * len(old_dim) + elif product(tuple(old_dim)) == 1: + assert len(dim) == 1 + broadcast_factor.append(dim[0]) + else: + raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {dim}") + + # flatten_split_shape -> split_input_shape + factor_idx = 0 + broadcast_split_input_shape = [] + for dim in split_input_shape: + if isinstance(dim, list): + new_dim = [] + for d in dim: + new_dim.append(d * broadcast_factor[factor_idx]) + factor_idx += 1 + broadcast_split_input_shape.append(new_dim) + else: + broadcast_split_input_shape.append(dim * broadcast_factor[factor_idx]) + factor_idx += 1 + broadcast_split_input_shape = _list_to_tuple(broadcast_split_input_shape) + node_meta.tensor.reshape(_list_to_tuple(split_input_shape)) + node_meta.tensor.broadcast(broadcast_split_input_shape) + # Last reshape op to clean up + broadcast_input_shape = tuple([product(dim) for dim in broadcast_split_input_shape]) + node_meta.tensor.reshape(broadcast_input_shape) + # Update the input shape and output shape + self.input_shape = _list_to_tuple(node_meta.tensor.shape) + self.output_shape = _list_to_tuple(shape) + + def apply_to_user(self, 
user_meta: NodeBase): + """ + Propagate the reshape to user nodes + """ + user_meta.tensor.reshape(tuple(self.input_shape)) + if hasattr(user_meta, "store_tensor"): + if user_meta.store_tensor is not None: + user_meta.store_tensor.reshape(tuple(self.input_shape)) + + def apply_to_input(self, input_meta: NodeBase): + """ + Propagate the reshape to input nodes + """ + input_meta.tensor.reshape(tuple(self.output_shape)) + if hasattr(input_meta, "store_tensor"): + if input_meta.store_tensor is not None: + input_meta.store_tensor.reshape(tuple(self.output_shape)) + + # + # Helper functions + # + + def infer_split(self, input_shape, output_shape): + """ + Infer the flatten splitted shape that can be merged to both input_shape and output_shape + """ + input_shape = _tuple_to_list(input_shape) + output_shape = _tuple_to_list(output_shape) + if len(input_shape) == 0 and len(output_shape) == 0: + return [] + if len(input_shape) == 0: + if product(tuple(output_shape)) != 1: + raise ValueError("Invalid reshape size") + else: + return output_shape + if len(output_shape) == 0: + if product(tuple(input_shape)) != 1: + raise ValueError("Invalid reshape size") + else: + return input_shape + # This is done recursively by only process the last dimension at each time + old_dim = input_shape[-1] + new_dim = output_shape[-1] + # Exact match + if old_dim == new_dim: + return self.infer_split(input_shape[:-1], output_shape[:-1]) + [new_dim,] + # Needs split + if old_dim > new_dim and old_dim % new_dim == 0: + residual = old_dim // new_dim + return self.infer_split(input_shape[:-1] + [residual,], output_shape[:-1]) + [new_dim,] + # Needs merge + if old_dim < new_dim and new_dim % old_dim == 0: + residual = new_dim // old_dim + return self.infer_split(input_shape[:-1], output_shape[:-1] + [residual,]) + [old_dim,] + + raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}") + + def infer_merge(self, flatten_shape, shape): + flatten_shape = 
_tuple_to_list(flatten_shape) + shape = _tuple_to_list(shape) + idx_flat = len(flatten_shape) - 1 + merged_shape = [] + for dim in reversed(shape): + # Exact match + if dim == flatten_shape[idx_flat]: + merged_shape.append(dim) + idx_flat -= 1 + # need group + elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0: + residual = dim + group = [] + while(residual > 1): + group.append(flatten_shape[idx_flat]) + residual = residual // flatten_shape[idx_flat] + idx_flat -= 1 + merged_shape.append(group[::-1]) + else: + raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}") + + return merged_shape[::-1] + + +class LayoutNode(NodeBase): + """ + Layout manipulation nodes + """ + fn_to_impl = { + "permute": PermutationImpl, + "reshape": ReshapeImpl + } + def __init__(self, name: str, fn, kwargs: dict) -> None: + super().__init__(name) + self.op = "layout" + self.fn = fn + self.kwargs = kwargs + self.underlying_impl = self.fn_to_impl[self.fn.__name__](self) + + def get_inverse_node(self): + inverse_node = deepcopy(self) + inverse_node.underlying_impl = self.underlying_impl.get_inverse_impl() + return inverse_node + + def shape_propagation(self, input_node_metas): + if self._tensor is not None: + return + assert len(input_node_metas) == 1, "Layout node can only have one input node" + + output_shape = self.underlying_impl.shape_propagation(input_node_metas[0]) + + self._tensor = Tensor( + element=self.element_output, + shape=output_shape, layout_tag=LayoutType.RowMajor + ) + + return super().shape_propagation(input_node_metas) + + def type_propagation(self, input_node_metas: 'list[NodeBase]'): + """ + The store nodes has element_output = element_input + """ + assert len(input_node_metas) == 1, "Layout node can only have one input node" + self.element_output = input_node_metas[0].element_output + + def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'): + """ + Propagate the broadcast in the reversed topological order + """ + 
if self.tensor is None: + raise RuntimeError(f"The tensor of node {self.name} is unknown.") + shape = self.tensor.shape + + for child in input_node_metas: + self.underlying_impl.broadcast(shape, child) + + def apply_to_user(self, usr_meta: NodeBase): + """ + Propagate the permutation to user nodes + """ + self.underlying_impl.apply_to_user(usr_meta) + + def apply_to_input(self, input_meta: NodeBase): + """ + Propagate the permutation to input nodes + """ + self.underlying_impl.apply_to_input(input_meta) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/load_nodes.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/load_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..bff0aaa2c21ef2545f50745cdb33499270eeb9fb --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/load_nodes.py @@ -0,0 +1,294 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Load nodes and implementations +""" + +import ctypes + +from cutlass_cppgen.backend.c_types import tuple_factory +from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value +from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase + + +class LoadImplBase(ImplBase): + """ + Base class for load node implementations + """ + reserved_names = ["accum", "C"] + def __init__(self, node) -> None: + super().__init__(node) + self.element = node.element + self.element_output = node.element_output + self.stride = node.tensor.stride + + +class AccumulatorImpl(LoadImplBase): + """ + Accumulator node implementation + """ + + @staticmethod + def match(node, problem_size: tuple): + return node.name == "accum" and node.tensor.shape == problem_size + + +class LoadSrcImpl(LoadImplBase): + """ + Load C implementation + """ + @property + def name_camel(self) -> str: + return "TensorC" + + @property + def argument_type_c(self): + stride_mnl = self.get_stride_mnl() + tuple_type = tuple_factory(stride_mnl, self.stride_dtype) + class _Argument(ctypes.Structure): + _fields_ = [ + ("ptr_C", 
ctypes.c_void_p), + ("stride_C", tuple_type) + ] + def __init__(self, ptr) -> None: + self.ptr_C = ptr + self.stride_C = tuple_type(stride_mnl) + + return _Argument + + @staticmethod + def match(node, problem_size: tuple): + return node.name == "C" and node.tensor.shape == problem_size + + +class AuxLoadImpl(LoadImplBase): + """ + Load arbitrary tensor + """ + @property + def argument_type(self): + stride_mnl = self.get_stride_mnl() + name = self.name + tuple_type = tuple_factory(stride_mnl, self.stride_dtype) + element_type = self.element + class _Argument(ctypes.Structure): + _fields_ = [ + ("ptr_aux", ctypes.c_void_p), + ("null_default", dtype2ctype[element_type]), + ("dAux", tuple_type) + ] + def __init__(self, kwargs) -> None: + ptr = kwargs[name] + self.ptr_aux = ptr + self.null_default = to_ctype_value(0, element_type) + self.dAux = tuple_type(stride_mnl) + + return _Argument + + @staticmethod + def match(node, problem_size: tuple): + if node.name in LoadImplBase.reserved_names: + return False + strideMN = node.tensor.stride[-2:] + if (strideMN[0] == 1 and strideMN[1] != 0 or + strideMN[0] != 0 and strideMN[1] == 1 ): + return True + else: + return False + + +class RowBroadcastImpl(LoadImplBase): + """ + Broadcast a row vector + """ + def __init__(self, node) -> None: + super().__init__(node) + self.stride_dtype = "int" + + @property + def argument_type(self): + stride_mnl = self.get_stride_mnl() + name = self.name + tuple_type = tuple_factory(stride_mnl, self.stride_dtype) + element_type = self.element + class _Argument(ctypes.Structure): + _fields_ = [ + ("ptr_row", ctypes.c_void_p), + ("null_default", dtype2ctype[element_type]), + ("dRow", tuple_type) + ] + def __init__(self, kwargs) -> None: + ptr = kwargs[name] + self.ptr_row = ptr + self.null_default = to_ctype_value(0, element_type) + self.dRow = tuple_type(stride_mnl) + + return _Argument + + @staticmethod + def match(node, problem_size: tuple): + if node.name in LoadImplBase.reserved_names: + return 
False + + strideMN = node.tensor.stride[-2:] + if strideMN == (0, 1): + return True + else: + return False + + +class ColumnBroadcastImpl(LoadImplBase): + """ + Broadcast a column vector + """ + def __init__(self, node) -> None: + super().__init__(node) + self.stride_dtype = "int" + + @property + def argument_type(self): + stride_mnl = self.get_stride_mnl() + name = self.name + tuple_type = tuple_factory(stride_mnl, self.stride_dtype) + element_type = self.element + class _Argument(ctypes.Structure): + _fields_ = [ + ("ptr_col", ctypes.c_void_p), + ("null_default", dtype2ctype[element_type]), + ("dCol", tuple_type) + ] + def __init__(self, kwargs) -> None: + ptr = kwargs[name] + self.ptr_col = int(ptr) + self.null_default = to_ctype_value(0, element_type) + self.dCol = tuple_type(stride_mnl) + + return _Argument + + @staticmethod + def match(node, problem_size: tuple): + if node.name in LoadImplBase.reserved_names: + return False + + strideMN = node.tensor.stride[-2:] + if strideMN == (1, 0): + return True + else: + return False + + +class ScalarBroadcastImpl(LoadImplBase): + """ + Broadcast a scalar + """ + def __init__(self, node) -> None: + super().__init__(node) + self.stride_dtype = "int" + + @property + def argument_type(self): + stride_mnl = self.get_stride_mnl() + name = self.name + tuple_type = tuple_factory(stride_mnl, self.stride_dtype) + element_type = self.element + + if self.tensor.is_constant: + value = self.tensor.value + class _Argument(ctypes.Structure): + _fields_ = [ + ("scalars", dtype2ctype[element_type]), + ("scalar_ptrs", ctypes.c_void_p), + ("dScalar", tuple_type) + ] + def __init__(self, kwargs) -> None: + self.scalars = to_ctype_value(value, element_type) + self.scalar_ptrs = 0 + self.dScalar = tuple_type(stride_mnl) + + else: + class _Argument(ctypes.Structure): + _fields_ = [ + ("scalars", dtype2ctype[element_type]), + ("scalar_ptrs", ctypes.c_void_p), + ("dScalar", tuple_type) + ] + def __init__(self, kwargs) -> None: + scalar_or_ptr = 
kwargs[name] + if isinstance(scalar_or_ptr, float): + self.scalars = to_ctype_value(scalar_or_ptr, element_type) + self.scalar_ptrs = 0 + else: + self.scalar_ptrs = int(scalar_or_ptr) + + self.dScalar = tuple_type(stride_mnl) + + return _Argument + + @staticmethod + def match(node, problem_size: tuple): + if node.name in LoadImplBase.reserved_names: + return False + + strideMN = node.tensor.stride[-2:] + if strideMN == (0, 0): + return True + else: + return False + + +class LoadNode(NodeBase): + """ + Load Node + """ + cnt = 0 + possible_impls = [ + AccumulatorImpl, LoadSrcImpl, AuxLoadImpl, + RowBroadcastImpl, ColumnBroadcastImpl, + ScalarBroadcastImpl + ] + def __init__(self, name: str) -> None: + if name is None: + name = f"load{LoadNode.cnt}" + LoadNode.cnt += 1 + super().__init__(name) + self.op = "load" + + def type_propagation(self, *args, **kwargs): + """ + Load node loads tensor under type `tensor.element` and returns an array of type `tensor.element`. + """ + if self.tensor is None: + raise RuntimeError(f"The tensor of node {self.name} is unknown.") + + self.element = self.tensor.element + self.element_output = self.tensor.element diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/node.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/node.py new file mode 100644 index 0000000000000000000000000000000000000000..606591b8e78c97114b85b329050d630d55460d7a --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/node.py @@ -0,0 +1,306 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Base & visitor classes of DAGIR Nodes +""" + +import ctypes +from re import sub + +from cutlass_library import LayoutType + +from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple +from cutlass_cppgen.backend.evt.ir.tensor import Tensor + + +class TupleEmitter: + """ + Emit the cute tuple to C++ code + """ + def __init__(self, stride_dtype): + self.stride_dtype = stride_dtype + + def emit(self, py_tuple): + if isinstance(py_tuple, int): + if py_tuple in [0, 1]: + return f"cute::Int<{py_tuple}>" + else: + return f"{self.stride_dtype}" + elif isinstance(py_tuple, tuple): + decl = "cute::Stride<" + for item in py_tuple: + decl += self.emit(item) + ", " + return decl[:-2] + ">" + else: + raise ValueError(f"TupleEmitter.emit only accepts tuple or int, got {type(py_tuple).__name__}") + + +class ImplBase: + """ + Base class for Node Implementation + """ + def __init__(self, node) -> None: + self.node = node + self.name = node.name + self.tensor = node.tensor + self._type_decl = None + self.tuple_emitter = TupleEmitter("int64_t") + + @property + def stride_dtype(self): + return self.tuple_emitter.stride_dtype + + @stride_dtype.setter + def stride_dtype(self, stride_dtype): + self.tuple_emitter.stride_dtype = stride_dtype + + @staticmethod + def match(node, problem_size: tuple): + """ + Match function used in get_underlying_impl + """ + raise NotImplementedError(f"The `match` function is not defined.") + + @property + def argument_type(self): + """ + Default class for Argument Type + """ + class _Argument(ctypes.Structure): + _fields_ = [] + + def __init__(self, *args, **kwargs) -> None: + pass + + return _Argument + + @property + def name_camel(self) -> str: + """ + Return the CamelCase name. 
class NodeBase:
    """
    Common base for every node of the EVT DAG.

    A node carries its name, the implementation selected for it
    (``underlying_impl``), its output tensor and a ``disabled`` flag.
    Subclasses are expected to provide ``possible_impls`` and
    ``element_output``; neither is defined here.
    """
    def __init__(self, name: str) -> None:
        self.name = name
        # Filled in later by `get_underlying_impl`
        self.underlying_impl = None

        # Output tensor; stays None until it is assigned or propagated
        self._tensor = None

        # Whether the node is disabled for emit
        self.disabled = False

    @property
    def name_camel(self) -> str:
        """
        CamelCase name, delegated to the selected implementation.
        """
        return self.underlying_impl.name_camel

    @property
    def tensor(self) -> Tensor:
        """
        Output tensor of the node (concept: cutlass_cppgen.backend.evt.ir.tensor)
        """
        return self._tensor

    @tensor.setter
    def tensor(self, kwargs):
        """
        Build the output tensor from a dict of `Tensor` keyword arguments.
        """
        self._tensor = Tensor(**kwargs)

    #
    # Helper functions for type/shape propagation
    #

    def shape_propagation(self, input_node_metas):
        """
        Infer the output shape from the input nodes with NumPy-style broadcasting.

        Shapes are aligned on their trailing dimension and the shorter one is
        padded with leading 1s. Two dimensions are compatible when they are
        equal or when one of them is 1.

        Does nothing when the tensor is already known. Raises RuntimeError
        when two input shapes cannot be broadcast together.
        """
        if self._tensor is not None:
            return

        result = None
        for meta in input_node_metas:
            incoming = meta.tensor.shape
            if result is None:
                # The first input seeds the running shape
                result = incoming
                continue

            # Left-pad both shapes with 1s up to the same rank
            rank = max(len(result), len(incoming))
            lhs = [1] * (rank - len(result)) + list(result)
            rhs = [1] * (rank - len(incoming)) + list(incoming)

            merged = []
            for lhs_dim, rhs_dim in zip(lhs, rhs):
                if lhs_dim == 1:
                    merged.append(rhs_dim)
                elif rhs_dim == 1 or lhs_dim == rhs_dim:
                    merged.append(lhs_dim)
                else:
                    described = ", ".join(
                        f"{meta_.name}{meta_.tensor.shape}" for meta_ in input_node_metas)
                    raise RuntimeError("Dimension mismatch between " + described + ".")
            result = tuple(merged)

        self._tensor = Tensor(element=self.element_output, shape=result, layout_tag=LayoutType.RowMajor)

    def type_propagation(self, *args, **kwargs):
        """
        Each node is associated with two data types: `element` and `element_output`.
        `element_output` is the type of the array returned by the node, while
        `element` depends on the node kind:
        * Load Node: data type of tensor in gmem
        * Compute Node: element compute
        * Store Node: data type of tensor in gmem
        Derived classes must override this function; the base version always
        raises NotImplementedError.
        """
        owner = self.__class__.__name__
        raise NotImplementedError(f"Function `type_propagation` is not overloaded in {owner}")

    def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Propagate the broadcast in the reversed topological order.
        For example:
            C[l, m, n] = A[m, 1] + B[l, m, n]
        After the broadcast propagation, it becomes
            C[l, m, n] = A[l, m, n] + B[l, m, n]
        Each input tensor is broadcast to the shape of this node's tensor.
        Raises RuntimeError when this node's tensor is not known yet.
        """
        own = self.tensor
        if own is None:
            raise RuntimeError(f"The tensor of node {self.name} is unknown.")
        for producer in input_node_metas:
            producer.tensor.broadcast(own.shape)

    def get_underlying_impl(self, problem_size: tuple):
        """
        Select the implementation of the current node.

        The first entry of `possible_impls` whose `match` accepts the node is
        instantiated. Raises RuntimeError when the tensor is unknown and
        NotImplementedError when no implementation ends up being set.
        """
        if self.tensor is None:
            raise RuntimeError(f"The Layout of node {self.name} is unknown. Please call PassShapeTypePropagation first.")

        chosen = next(
            (candidate for candidate in self.possible_impls
             if candidate.match(self, problem_size)),
            None)
        if chosen is not None:
            self.underlying_impl = chosen(self)

        if self.underlying_impl is None:
            raise NotImplementedError(f"No matching op for node {self.name} with stride {self.tensor.stride}.")
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
class AuxStoreImpl(StoreImplBase):
    """
    Store implementation for an auxiliary output tensor.

    Matches output store nodes, other than the reserved names, whose last
    two strides have one unit entry and a non-zero other entry.
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        self.round_style = FloatRoundStyle.ToNearest

    @property
    def argument_type(self):
        """
        Build the ctypes structure holding the aux pointer and its stride.

        The returned class is constructed from a kwargs dict and reads the
        pointer from the entry keyed by this implementation's name.
        """
        strides = self.get_stride_mnl()
        key = self.name
        stride_struct = tuple_factory(strides, self.stride_dtype)

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_aux", ctypes.c_void_p),
                ("dAux", stride_struct),
            ]

            def __init__(self, kwargs) -> None:
                self.ptr_aux = kwargs[key]
                self.dAux = stride_struct(strides)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Return True when `node` can be implemented as an aux store.
        """
        # Only real outputs that are not reserved (e.g. "D") qualify
        if not node.is_output or node.name in StoreImplBase.reserved_names:
            return False

        stride_mn = node.store_tensor.stride[-2:]
        # One of the two trailing strides is 1 and the other is not 0
        return bool(
            (stride_mn[0] == 1 and stride_mn[1] != 0)
            or (stride_mn[0] != 0 and stride_mn[1] == 1)
        )
class StoreNode(NodeBase):
    """
    DAG node with op "store".

    Besides the tensor inherited from NodeBase, it keeps a separate
    `store_tensor`, which the propagation methods only consult when the
    node is marked as an output.
    """
    # Candidate implementations, tried in this order
    possible_impls = [
        AuxStoreImpl,
        RowReductionImpl,
        ColumnReductionImpl,
        ScalarReductionImpl,
        NoOpImpl,
        StoreDImpl,
    ]

    def __init__(self, name: str) -> None:
        super().__init__(name)
        self.op = "store"
        self.is_output = False
        self._store_tensor = None

    @property
    def store_tensor(self) -> Tensor:
        """
        Tensor the node stores to (concept: cutlass_cppgen.backend.evt.ir.tensor)
        """
        return self._store_tensor

    @store_tensor.setter
    def store_tensor(self, kwargs):
        """
        Build the store tensor from a dict of `Tensor` keyword arguments.
        """
        self._store_tensor = Tensor(**kwargs)

    def type_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        A store node forwards the output element type of its single input.

        For output nodes, `element` is additionally taken from the store
        tensor; RuntimeError is raised when that tensor is missing.
        """
        if self.is_output:
            target = self.store_tensor
            if target is None:
                raise RuntimeError(f"The store tensor of node {self.name} is unknown.")
            self.element = target.element
        assert len(input_node_metas) == 1, "Store node can only have one input node"
        (producer,) = input_node_metas
        self.element_output = producer.element_output

    def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Broadcast the inputs as in NodeBase; output nodes also broadcast
        their store tensor to the node's tensor shape.
        """
        super().broadcast_propagation(input_node_metas)
        if not self.is_output:
            return
        self._store_tensor.broadcast(self.tensor.shape)
class Tensor:
    """
    Element type plus layout of a tensor used by the EVT IR.

    Built either from an existing tensor object (`tensor`) or from an
    explicit description (`element`, `shape` and one of `stride` /
    `layout_tag`); the two forms are mutually exclusive.
    """
    def __init__(self, tensor=None, element=None, shape=None, stride=None,layout_tag=None, is_constant=False) -> None:
        # Argument validation: the checks run in the same order as the
        # messages are listed, grouped by whether `tensor` was given.
        if tensor is not None:
            if element is not None:
                raise Exception("Must not specify both element and tensor")
            if shape is not None:
                raise Exception("Must not specify both shape and tensor")
            if layout_tag is not None:
                raise Exception("Must not specify both layout_tag and tensor")
            if stride is not None:
                raise Exception("Must not specify both stride and tensor")
        else:
            if element is None or shape is None or (layout_tag is None and stride is None):
                raise Exception("Must specify one of (element, shape, layout/stride) or (tensor)")
            if stride is not None and layout_tag is not None:
                raise Exception("Must not specify layout_tag when stride is provided")

        if isinstance(tensor, Tensor):
            # Copy construction: take over every attribute of the source
            self.__dict__.update(vars(tensor))
        else:
            if tensor is None:
                self.element = library_type(element)
            else:
                # Derive element type, layout tag and shape from the tensor object
                self.element, layout_tag = get_datatype_and_layout(tensor)
                shape = get_tensor_shape(tensor)

            # The layout is stored with shape and stride reversed; the
            # `shape` / `stride` properties reverse them back
            if stride is not None:
                self.layout = Layout(shape[::-1], stride[::-1])
            elif layout_tag == LayoutType.RowMajor:
                self.layout = Layout(shape[::-1])
            elif layout_tag == LayoutType.ColumnMajor:
                descending = list(range(len(shape) - 1, -1, -1))
                self.layout = permutation(Layout(shape), descending)
            self.layout = canonicalization(self.layout)

            self.is_constant = is_constant
            # Save the tensor value if it is constant
            if is_constant and tensor is not None:
                self.value = tensor

    @property
    def shape(self):
        """
        Layout shape in reversed order (documented upstream as the RowMajor shape)
        """
        return _reverse_tuple(self.layout.shape)

    @property
    def stride(self):
        """
        Layout stride in reversed order (documented upstream as the RowMajor stride)
        """
        return _reverse_tuple(self.layout.stride)

    @property
    def rank(self):
        """
        Number of dimensions of `shape`
        """
        return len(self.shape)

    #
    # Layout Algorithms
    #

    def broadcast(self, shape):
        """
        Broadcast self.layout to `shape` (must be a tuple)
        """
        assert isinstance(shape, tuple)
        target = _reverse_tuple(shape)
        self.layout = broadcast(self.layout, target)

    def reshape(self, shape):
        """
        Reshape self.layout to `shape` (must be a tuple)
        """
        assert isinstance(shape, tuple)
        self.layout = reshape(self.layout, _reverse_tuple(shape))

    def permute(self, indices):
        """
        Permute self.layout according to `indices`
        """
        count = len(indices)
        # Mirror each index, then reverse the order, to address the
        # reversed layout
        mirrored = [count - idx - 1 for idx in indices]
        mirrored.reverse()
        self.layout = permutation(self.layout, mirrored)
a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..badc38d96a830992c94afa693ea4b56a8e404c96 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/__init__.py @@ -0,0 +1,42 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_cppgen.backend.evt.passes.graph_drawer import EVTGraphDrawer +from cutlass_cppgen.backend.evt.passes.pass_argument_type import PassGetArgumentType +from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree +from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl +from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD +from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassManager +from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed +from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation +from cutlass_cppgen.backend.evt.passes.smem_size_calculator import GetSmemSize diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/graph_drawer.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/graph_drawer.py new file mode 100644 index 0000000000000000000000000000000000000000..8a28c6e4e62d1a7bd7431c81aac366b8788fd8df --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/graph_drawer.py @@ -0,0 
+1,143 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
class EVTGraphDrawer:
    """
    Visualize a EVT DAGIR with graphviz

    One dot graph is built for the top-level graph and one for the subgraph
    of every "dag" node; all of them are kept in a dict keyed by name.
    """
    def __init__(
        self,
        graph: "DAGIR",
        name: str
    ):
        self._name = name
        self._dot_graphs = {}

        self._dot_graphs[name] = self._to_dot(graph, name)

    def _get_node_style(self, node):
        """
        Return the graphviz style attributes for `node`.

        The fill color is looked up by `node.op`; disabled nodes are drawn
        grey on white. Raises NotImplementedError for an unknown op.
        """
        if node.op not in _COLOR_MAP:
            raise NotImplementedError("unknown node op")
        template = {
            "shape": "record",
            "fillcolor": _COLOR_MAP[node.op],
            "style": '"filled,rounded"',
            "fontcolor": "#000000",
        }
        if node.disabled:
            template["fontcolor"] = "grey"
            template["fillcolor"] = "white"
        return template

    def _get_node_label(self, node):
        """
        Return the record label of `node`: "{field|field|...}".

        Fields cover name and op, layout function and kwargs, the selected
        implementation with its element types, and the shape/stride of the
        node tensor and of the store tensor when those are available.
        """
        fields = [f"name={node.name}", f"op={node.op}"]
        if node.op == "layout":
            fields.append(f"fn={node.fn.__name__}")
            fields.extend(f"{key}={value}" for key, value in node.kwargs.items())

        impl = node.underlying_impl
        if impl is not None:
            fields.append(f"impl={type(impl).__name__}")
            if node.op == "load":
                fields.append(f"element_output={DataTypeTag[impl.element]}")
            elif node.op == "compute":
                fields.append(f"element_compute={DataTypeTag[impl.element_compute]}")
                fields.append(f"element_output={DataTypeTag[impl.element_output]}")
            elif node.op == "store":
                fields.append(f"element_store={DataTypeTag[impl.element]}")
                fields.append(f"element_output={DataTypeTag[impl.element_output]}")
            elif node.op == "dag":
                fields.append(f"element_output={DataTypeTag[impl.element_output]}")

        if node.tensor is not None:
            fields.append(f"shape={node.tensor.shape}")
            fields.append(f"stride={node.tensor.stride}")

        # Only some node kinds expose a store tensor
        store_tensor = getattr(node, "store_tensor", None)
        if store_tensor is not None:
            fields.append(f"store_shape={store_tensor.shape}")
            # Fixed: this field used to be emitted under the key "stride_stride"
            fields.append(f"store_stride={store_tensor.stride}")

        return "{" + "|".join(fields) + "}"

    def _to_dot(
        self,
        graph: "DAGIR",
        name: str
    ):
        """
        Convert `graph` into a pydot graph called `name`.

        Subgraphs of "dag" nodes are converted recursively and registered in
        `self._dot_graphs` under the node name.
        """
        # Imported here so that pydot is only needed when drawing
        import pydot
        # Fixed: the graphviz attribute is `rankdir`; it was misspelled `randir`
        dot_graph = pydot.Dot(name, rankdir="TB")
        for node in graph.nodes_meta:
            style = self._get_node_style(node)
            label = self._get_node_label(node)
            dot_graph.add_node(pydot.Node(node.name, label=label, **style))
            if node.op == "dag":
                self._dot_graphs[node.name] = self._to_dot(node.subgraph, name=node.name)

        # Add edges
        for src, dst in graph.edges:
            weight = graph.get_edge_weight(src, dst)
            dot_graph.add_edge(pydot.Edge(src, dst, label=weight))

        return dot_graph

    def get_dot_graph(self) -> "list[tuple[str, pydot.Dot]]":
        """
        Return every generated graph as a list of (name, dot graph) pairs.
        """
        return [(key, self.get_dot_graph_by_name(key)) for key in self._dot_graphs]

    def get_dot_graph_by_name(self, name) -> "pydot.Dot":
        """
        Return the dot graph registered under `name` (KeyError if absent).
        """
        return self._dot_graphs[name]

    def get_main_dot_graph(self) -> "pydot.Dot":
        """
        Return the dot graph of the top-level graph.
        """
        return self._dot_graphs[self._name]
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
class PassGetArgumentType(EVTPassBase):
    """
    Construct the epilogue visitor argument type

    Walks the DAG IR in topological order, collects the argument type of
    every enabled node in `self.argument_types` and combines them with
    `visitor_factory`. The final types are stored on `self.dag_ir` by the
    compute-capability specific `*_set_argument_type` method.
    """
    dependencies = [
        PassShapeTypePropagation,   # shape/type propagation must have run
        PassDAG2Tree,               # non-tree subgraphs must already be fused into DAG nodes
        PassGetImpl                 # presumably sets `underlying_impl` on each node -- confirm
    ]

    def requires(self) -> None:
        """
        Check the precondition of the pass.

        Raises SyntaxError when the target maps to 90 or 100 in `cc_map`
        and the DAG IR has no node named "D".
        """
        # Check "D" is in the node list
        if cc_map[self.cc] in [90, 100] and (not self.dag_ir.has_node("D")):
            raise SyntaxError(
                "Sm90+ EVT requires the epilogue to have a returned tensor D, "
                "but the variable 'D' is not found in the return values.")

    def call(self):
        """
        Build `self.argument_types` for all nodes, then store the result
        through the cc-specific variant of `set_argument_type`.
        """
        nodes = self.dag_ir.nodes_topological_order()
        # Maps node name -> argument type
        self.argument_types = {}
        for node in nodes:
            meta = self.dag_ir.get_node_meta(node)
            if not meta.disabled:
                self.argument_types[node] = meta.underlying_impl.argument_type
            # On 90/100 the "D" node is not wrapped in a visitor; its type is
            # read separately through `argument_type_d`
            if node == "D" and cc_map[self.cc] in [90, 100]:
                continue
            if isinstance(meta, TopoVisitorNode):
                self.get_dag_argument_type(node)
            else:
                self.get_evt_argument_type(node)

        # Dispatch to sm80/sm90/sm100_set_argument_type
        # (presumably resolved by name from self.cc -- confirm in EVTPassBase)
        self.cc_specific_method(self.set_argument_type)()

    def get_evt_argument_type(self, node):
        """
        Combine the argument types of the inputs of `node` with its own type.

        Nodes without inputs keep the type registered in `call`.
        """
        # Input order is whatever `get_all_inputs` returns
        # (presumably sorted by edge weight -- confirm in DAGIR)
        input_types = [self.argument_types[child] for child in self.dag_ir.get_all_inputs(node)]
        if len(input_types) > 0:
            # The node's own type goes last, after all of its inputs
            self.argument_types[node] = visitor_factory(
                input_types + [self.argument_types[node],], self.dag_ir.get_all_inputs(node) + [node,])

    def get_dag_argument_type(self, node):
        """
        Build the argument type of a fused DAG node from its subgraph nodes.
        """
        meta = self.dag_ir.get_node_meta(node)
        subgraph = meta.subgraph
        subgraph_nodes = subgraph.nodes_topological_order()
        # Register the argument type of every enabled node in the subgraph
        for n in subgraph_nodes:
            m = subgraph.get_node_meta(n)
            if m.disabled:
                continue
            else:
                self.argument_types[n] = m.underlying_impl.argument_type
        # The last node in topological order is left out of the visitor.
        # NOTE(review): a disabled subgraph node that was never registered
        # would raise KeyError here -- confirm this cannot happen.
        input_types = [self.argument_types[child] for child in subgraph_nodes[:-1]]
        if len(input_types) > 0:
            self.argument_types[node] = visitor_factory(input_types, subgraph_nodes[:-1])

    def set_argument_type(self):
        """
        Placeholder passed to `cc_specific_method`; does nothing itself.
        """
        pass

    def sm90_set_argument_type(self):
        """
        Store the epilogue thread type and the C/D argument types on the DAG IR.
        """
        # The epilogue type is the one of the first input of "D"
        self.dag_ir.epilogue_thread_type = self.argument_types[self.dag_ir.get_all_inputs("D")[0]]
        # Get the tensorD argument type
        self.dag_ir.arg_d_type = self.dag_ir.get_node_meta("D").underlying_impl.argument_type_d

        # Get the tensorC argument type
        if self.dag_ir.has_node("C"):
            self.dag_ir.arg_c_type = self.dag_ir.get_node_meta("C").underlying_impl.argument_type_c
        else:
            # Without a "C" node the D argument type is reused for C
            self.dag_ir.arg_c_type = self.dag_ir.arg_d_type

    def sm100_set_argument_type(self):
        """
        Same as `sm90_set_argument_type`.
        """
        self.sm90_set_argument_type()

    def sm80_set_argument_type(self):
        """
        Use the argument type of the last node in topological order as the
        epilogue thread type.
        """
        nodes = self.dag_ir.nodes_topological_order()
        self.dag_ir.epilogue_thread_type = self.argument_types[nodes[-1]]
+# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Merge non-tree sub-graphs of the DAG IR into a single DAG. The fused DAG will be implemented +by the topological visitor, while the rest of the graph will be implemented with the tree visitor. 
+""" + +from copy import deepcopy + +from cutlass_cppgen.backend.evt.ir import DAGIR, TopoVisitorNode +from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase +from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation + + +class PassDAG2Tree(EVTPassBase): + """ + Convert the DAG IR to Tree by fusing subgraphs + """ + dependencies = [ + PassShapeTypePropagation, + PassGetImpl + ] + + def call(self): + # Step 1: find the nodes that have multiple parents + multi_parent_nodes = [] + + for node in self.dag_ir.nodes_topological_order(): + if self.dag_ir.out_degree(node) > 1: + multi_parent_nodes.append(node) + # Step 2: find the lowest common ancestor (LCA) of all its parents + for node in multi_parent_nodes: + # A multi-parent node could be already fused by the previous node + if not self.dag_ir.has_node(node): + continue + # A node uncovered by the previous fusions can have out degree change + # Case 1: it has <= 1 edges to the previously fused subgraph, no degree change + # Case 2: it has more than one edges to the previously fused subgraph, degree drops + if self.dag_ir.out_degree(node) <= 1: + continue + + # Otherwise, the node still + reachable_nodes = [] + # Complexity: O(Dout*N) + for parent in self.dag_ir.get_users(node): + reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent))) + # get the common reachable objects + common_items = set.intersection(*reachable_nodes) + node_to_fuse = set.union(*reachable_nodes).difference(common_items) + + lca = None + # If common ancestor exists, find the lowest one + if len(common_items) > 0: + topo_order = self.dag_ir.nodes_topological_order() + topo_idx = -1 + for item in common_items: + if lca is None: + lca = item + topo_idx = topo_order.index(item) + else: + if topo_idx > topo_order.index(item): + lca = item + topo_idx = topo_order.index(item) + else: + # there is no common ancestor 
for all the parents, we pack all the reachable + # nodes into a single DAG node as a fallback. The lca should be the input node of + # one of the output nodes with out_degree = 0 + potential_output_nodes = [] + for node in node_to_fuse: + if self.dag_ir.out_degree(node) == 0: + potential_output_nodes.append(node) + if len(potential_output_nodes) == 0: + raise RuntimeError(f"No output node with out degree = 0 found.") + + output_node = None + if (self.dag_ir.cc >= 90): + # For SM90+, the lca should be the input node of D + if (not self.dag_ir.has_node("D")): + raise RuntimeError(f"D is not a node in the DAG IR.") + output_node = "D" + else: + output_node = potential_output_nodes[0] + + if (output_node is None): + raise RuntimeError(f"No output node found.") + lca = self.dag_ir.get_all_inputs(output_node)[0] + node_to_fuse.remove(output_node) + + # The lca is the output node of the DAG node + # Get the nodes to be fused + node_to_fuse.add(lca) + # Get all the input nodes + all_input_nodes = [] + all_output_nodes = [] + for node in node_to_fuse: + all_input_nodes.append(set(self.dag_ir.get_all_inputs(node))) + all_output_nodes.append(set(self.dag_ir.get_users(node))) + all_input_nodes = set.union(*all_input_nodes) + all_output_nodes = set.union(*all_output_nodes) + + new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes) + + # Create the subgraph + subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes) + subgraph = DAGIR(self.dag_ir.cc) + for node in subgraph_.nodes: + meta = deepcopy(self.dag_ir.get_node_meta(node)) + if node not in node_to_fuse: + meta.disabled = True + subgraph.add_node(meta) + for edge in subgraph_.edges: + subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1])) + + + # Create the fused node + dag_node = TopoVisitorNode( + name=f"dag_{lca}", subgraph=subgraph, + output_node=self.dag_ir.get_node_meta(lca)) + self.dag_ir.add_node(dag_node) + + # Add input edges + for idx, node in 
enumerate(all_input_nodes): + self.dag_ir.add_edge(node, dag_node.name, weight=idx) + + # Replace all uses with DAG node (only 1 output node) + self.dag_ir.replace_all_uses_with(lca, dag_node.name) + + # Remove all fused nodes + node_to_fuse.remove(lca) + for node in node_to_fuse: + self.dag_ir.remove_node(node) + + def ensures(self) -> None: + # Ensure that after the pass, the resulting DAG becomes a tree + for node in self.dag_ir.nodes: + out_degree = self.dag_ir.out_degree(node) + if out_degree > 1: + raise RuntimeError(f"PassDAG2Tree failed. Node {node} still have outdegree = {out_degree}") diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py new file mode 100644 index 0000000000000000000000000000000000000000..0d57c5b799d125ccc9491760259569731c0bf3ca --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py @@ -0,0 +1,64 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Fix the element_output of producer of D. + +In Sm90 epilogue visitor, the node writing D to gmem does not have internal +element converter, so the compute node producing D must have element_output = type(D). 
+""" + +from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase + + +class PassFixElementD(EVTPassBase): + """ + In Sm90 epilogue visitor, the node writing D to gmem does not have internal + element converter, so the compute node producing D must have + element_output = type(D) + """ + dependencies = [ + PassLayoutManipulateElimination + ] + def get_producer(self, node, element_D): + node_meta = self.dag_ir.get_node_meta(node) + if node_meta.op == "compute": + node_meta.element_output = element_D + elif node_meta.op == "store": + self.get_producer(self.dag_ir.get_all_inputs(node)[0], element_D) + + def call(self): + if self.dag_ir.has_node("D"): + node_d_meta = self.dag_ir.get_node_meta("D") + element_D = node_d_meta.store_tensor.element + self.get_producer("D", element_D) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..90fdafe7d0e80492bd2e641c69f11d95aace6bba --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py @@ -0,0 +1,90 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Infer the underlying implement of each node. + +While the frontend only distinguish between Load/Store/Compute Node, +each of these nodes can have different underlying implementation based +on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc. 
+This pass infers the underlying impl of each node +""" + +import cutlass_cppgen.backend.evt.backend as evt_backend +from cutlass_cppgen.backend.evt.ir import DAGIR, LoadNode +from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase +from cutlass_cppgen.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination +from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation +from cutlass_cppgen.backend.evt.passes.util import cc_map + + +class PassGetImpl(EVTPassBase): + """ + While the frontend only distinguish between Load/Store/Compute Node, + each of these nodes can have different underlying implementation based + on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc. + This pass infers the underlying impl of each node + """ + dependencies = [ + PassShapeTypePropagation, # The shape and type info are required for inference + PassFixElementD + ] + + def __init__(self, dag_ir: DAGIR) -> None: + super().__init__(dag_ir) + self.no_op_elimination = PassNoOpElimination(dag_ir) + + def requires(self) -> None: + # Verify "accum" is in the arg list + if not self.dag_ir.has_node("accum"): + raise SyntaxError("Cannot find 'accum' in the argument list.") + + def call(self): + # The loop structure of the epilogue is determined by the + # accumulator shape + accumulator: LoadNode = self.dag_ir.get_node_meta("accum") + problem_size = accumulator.tensor.shape + + for node_meta in self.dag_ir.node_metas_topological_order(): + node_meta.get_underlying_impl(problem_size) + + def ensures(self) -> None: + # Some nodes will be lowered to NoOp, eliminate them + self.no_op_elimination() + # Lower to cc-specific impl + for node_meta in self.dag_ir.nodes_meta: + node_impl_ccs = getattr(evt_backend, f"sm{cc_map[self.cc]}_nodes") + node_meta.underlying_impl = getattr( + node_impl_ccs, + f"Sm{cc_map[self.cc]}" + 
node_meta.underlying_impl.__class__.__name__ + )(node_meta) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py new file mode 100644 index 0000000000000000000000000000000000000000..af147969f016b50ef05034fca99b173777948622 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py @@ -0,0 +1,217 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Eliminate layout manipulation nodes +""" + +from copy import deepcopy + +from cutlass_cppgen.backend.evt.ir import DAGIR, LayoutNode +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase +from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation + + +class PassLayoutManipulateElimination(EVTPassBase): + """ + Eliminate layout manipulation nodes + """ + dependencies = [PassShapeTypePropagation] + + def __init__(self, dag_ir: DAGIR) -> None: + super().__init__(dag_ir) + self.copy_cnt = 0 + + def call(self): + self.layout_nodes_worklist = self.get_all_layout_nodes() + # Run while loop utill all layout nodes are eliminated + while(len(self.layout_nodes_worklist) > 0): + node = self.layout_nodes_worklist.pop(0) + # for node in layout_nodes: + # Step 1: get the propagation direction + direction = self.get_propagation_direction(node) + self.visited = [] + getattr(self, f"propagate_to_{direction}")(self.dag_ir.get_node_meta(node), node) + # Eliminate the current node + input_node = self.dag_ir.get_all_inputs(node)[0] + self.dag_ir.replace_all_uses_with(node, input_node) + # layout_nodes = self.get_all_layout_nodes() + + def get_all_layout_nodes(self): + layout_nodes = [] + for node_meta in reversed(self.dag_ir.node_metas_topological_order()): + if isinstance(node_meta, LayoutNode): + 
layout_nodes.append(node_meta.name) + return layout_nodes + + def get_propagation_direction(self, node: str): + """ + The logic is propagating all layout nodes away from the accumulator node. + """ + self.visited = [] + self.get_influenced_users(node) + nodes_influenced_dir_users = self.visited + self.visited = [] + self.get_influenced_inputs(node) + nodes_influenced_dir_inputs = self.visited + + if "accum" in nodes_influenced_dir_users and "accum" not in nodes_influenced_dir_inputs: + return "inputs" + elif "accum" not in nodes_influenced_dir_users and "accum" in nodes_influenced_dir_inputs: + return "users" + else: + raise RuntimeError("Unsolved propagation direction") + + # Get all influenced nodes if we propagate along the user direction + def get_influenced_users(self, node: str): + if node in self.visited: + return + self.visited.append(node) + + users = self.dag_ir.get_users(node) + for user in users: + self.get_influenced_users(user) + user_inputs = [] + for user in users: + user_inputs.append(set(self.dag_ir.get_all_inputs(user))) + if len(user_inputs) > 0: + user_inputs = set.union(*user_inputs) + user_inputs.remove(node) + for input in user_inputs: + self.get_influenced_inputs(input) + + # Get all influenced nodes if we propagate along the input direction + def get_influenced_inputs(self, node: str): + if node in self.visited: + return + self.visited.append(node) + + inputs = self.dag_ir.get_all_inputs(node) + for input in inputs: + self.get_influenced_inputs(input) + input_users = [] + for input in inputs: + input_users.append(set(self.dag_ir.get_users(input))) + if len(input_users) > 0: + input_users = set.union(*input_users) + input_users.remove(node) + for user in input_users: + self.get_influenced_users(user) + + def add_copy_before(self, layout_node_meta: LayoutNode, target: str): + copied_node_meta = deepcopy(layout_node_meta) + copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}" + self.copy_cnt += 1 + copied_node_meta.name = copied_node 
+ self.dag_ir.add_node(copied_node_meta) + # Add edges + target_inputs = self.dag_ir.get_all_inputs(target) + for src in target_inputs: + self.dag_ir.remove_edge(src, target) + self.dag_ir.add_edge(src, copied_node) + self.dag_ir.add_edge(copied_node, target) + self.layout_nodes_worklist.append(copied_node) + + def add_copy_after(self, layout_node_meta: LayoutNode, target: str): + copied_node_meta = deepcopy(layout_node_meta) + copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}" + self.copy_cnt += 1 + copied_node_meta.name = copied_node + self.dag_ir.add_node(copied_node_meta) + # Add edges + users = self.dag_ir.get_users(target) + for user in users: + self.dag_ir.remove_edge(target, user) + self.dag_ir.add_edge(copied_node, user) + self.dag_ir.add_edge(target, copied_node) + self.layout_nodes_worklist.append(copied_node) + + # Propagate the layout `node` along the user direction + def propagate_to_users(self, layout_node_meta: LayoutNode, node: str): + """ + Propagate layout node to users + """ + if node in self.visited: + # Avoid applying twice + return + self.visited.append(node) + + node_meta = self.dag_ir.get_node_meta(node) + if layout_node_meta.name != node: + if isinstance(node_meta, LayoutNode): + # Layout node is not transparent with layout node + self.add_copy_before(layout_node_meta, node) + return + else: + layout_node_meta.apply_to_user(node_meta) + + users = self.dag_ir.get_users(node) + user_inputs = [] + for user in users: + user_inputs.append(set(self.dag_ir.get_all_inputs(user))) + for user in users: + self.propagate_to_users(layout_node_meta, user) + if len(user_inputs) > 0: + user_inputs = set.union(*user_inputs) + user_inputs.remove(node) + for input in user_inputs: + self.propagate_to_inputs(layout_node_meta.get_inverse_node(), input) + + # Propagate the layout `node` along the input direction + def propagate_to_inputs(self, layout_node_meta: LayoutNode, node: str): + """ + Propagate layout node to inputs + """ + if node in 
self.visited: + # Avoid applying twice + return + self.visited.append(node) + + node_meta = self.dag_ir.get_node_meta(node) + if layout_node_meta.name != node: + if isinstance(node_meta, LayoutNode): + # Layout node is not transparent with layout node + self.add_copy_after(layout_node_meta, node) + return + else: + layout_node_meta.apply_to_input(node_meta) + inputs = self.dag_ir.get_all_inputs(node) + input_users = [] + for input in inputs: + input_users.append(set(self.dag_ir.get_users(input))) + for input in inputs: + self.propagate_to_inputs(layout_node_meta, input) + if len(input_users) > 0: + input_users = set.union(*input_users) + input_users.remove(node) + for user in input_users: + self.propagate_to_users(layout_node_meta.get_inverse_node(), user) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_manager.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..e8b46bddb06e7c20be6d20526792777edef64b90 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_manager.py @@ -0,0 +1,164 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Pass manager for DAG IR. +""" + +from typing import Any + +import networkx as nx + +from cutlass_cppgen.backend.evt.ir import DAGIR +from cutlass_cppgen.backend.evt.passes.util import cc_map + + +class EVTPassBase: + """ + Base class for EVT Passes + """ + dependencies = [] + def __init__(self, dag_ir: DAGIR) -> None: + self.dag_ir = dag_ir + self.cc = self.dag_ir.cc + + def requires(self) -> None: + """ + This function will be called before the pass is run. 
+ """ + pass + + def call(self) -> None: + """ + The pass that is run through the self.dag_ir + """ + raise NotImplementedError( + f"__call__ is not overwritten in Pass {self.__class__.__name__}") + + def ensures(self) -> None: + """ + This function will be called after the pass is run. + """ + pass + + def __call__(self) -> Any: + self.requires() + self.call() + self.ensures() + + def cc_specific_method(self, func): + """ + This enables defining function that behaves differently under different cc + The simplest example of using this function is the following + + .. highlight:: python + .. code-block:: python + + class ExamplePass(EVTPassBase): + + def call(sekf): + # This automatically select the smXX_func based on current cc + self.cc_specific_method(self.func)() + + # Interface func, can be empty + def func(self): + pass + + # Sm90 specific func + def sm90_func(self): + // sm90 specific method + return + + # Sm80 specific func + def sm80_func(self): + // sm80 specific method + return + """ + func_name = f"sm{cc_map[self.cc]}_{func.__name__}" + if hasattr(self, func_name): + return getattr(self, func_name) + else: + raise NotImplementedError(f"func {func.__name__} is not overwritten for Sm{self.cc}") + + +class EVTPassManager(nx.DiGraph): + """ + Topological-based Pass Manager. + Each registered pass has a list of dependencies. The pass manager organizes + the passes as a DAG and launch the compiler passes under topological order. 
+ """ + def __init__(self, dag_ir: DAGIR, pass_list): + super().__init__() + self.dag_ir = dag_ir + for pass_cls in pass_list: + self.add_pass(pass_cls) + + self.sorted_passes = self.schedule() + + def get_callable(self, pass_name): + """ + Return the callable of the pass + """ + return self.nodes[pass_name]["callable"] + + def add_pass(self, pass_cls): + """ + Add a pass to the pass manager + :param pass_cls: the class of pass + :type pass_cls: derived class of EVTPassBase + """ + name = pass_cls.__name__ + pass_callable = pass_cls(self.dag_ir) + self.add_node(name, callable=pass_callable) + + def schedule(self): + """ + Schedule the added passes under topological order + """ + # Add edges + for pass_name in self.nodes: + callable = self.get_callable(pass_name) + for dependency_cls in callable.dependencies: + self.add_edge( + dependency_cls.__name__, + type(callable).__name__) + + # Topological sort + return list(nx.topological_sort(self)) + + def __call__(self) -> Any: + """ + Launch the registered passes + """ + for pass_name in self.sorted_passes: + callable = self.get_callable(pass_name) + callable() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py new file mode 100644 index 0000000000000000000000000000000000000000..13107eb1d11c9a436348a4e50a92e62ce6f8b312 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py @@ -0,0 +1,53 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +No op elimination node +""" + +from typing import Any + +from cutlass_cppgen.backend.evt.ir import NoOpImpl +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase + + +class PassNoOpElimination(EVTPassBase): + """ + The dead node elimination pass removes nodes with NoOpImpl in DAG IR + """ + dependencies = [] + + def call(self) -> Any: + for node in self.dag_ir.nodes_topological_order(): + node_meta = self.dag_ir.get_node_meta(node) + if isinstance(node_meta.underlying_impl, NoOpImpl): + self.dag_ir.replace_all_uses_with(node, self.dag_ir.get_all_inputs(node)[0]) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py new file mode 100644 index 0000000000000000000000000000000000000000..6423a2b845dd643650cf99037178030bee6f0dbd --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py @@ -0,0 +1,97 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Preprocess the reduction nodes. + +The parser treats reduction as Compute(op=(reg_reduce_fn, gmem_reduce_fn)) - Store() +This pass fuses these into a single store node, and then replaces all uses of the +current node with the new store node. 
+""" + +from cutlass_cppgen.backend.evt.ir import ComputeNode, StoreNode +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase + + +class PassPreprocessRed(EVTPassBase): + """ + Preprocess red nodes + """ + + def call(self): + # Step 1: find the compute nodes with op=red + red_compute_nodes = [] + for node_meta in self.dag_ir.nodes_meta: + if isinstance(node_meta, ComputeNode): + if type(node_meta.fn) == tuple: + # To keep the frontend simple, the reduction nodes + # are parsed into compute nodes by default + # The simple heuristic to distinguish between compute + # and reduction node is that compute node is a single function, + # while the reduction node is a tuple of functions for + # in-register reduction and atomic global memory reduction + red_compute_nodes.append(node_meta.name) + + # Step 2: for each compute, merge it with the succeeding store + for node in red_compute_nodes: + # Verify + users = self.dag_ir.get_users(node) + inputs = self.dag_ir.get_all_inputs(node) + # Has a single user + assert len(users) == 1 + assert len(inputs) == 1 + user = users[0] + input = inputs[0] + + user_meta = self.dag_ir.get_node_meta(user) + # Must be a store node + assert isinstance(user_meta, StoreNode) + # With output degree == 0 + assert self.dag_ir.out_degree(user) == 0 + # Register the reduce op + node_meta = self.dag_ir.get_node_meta(node) + user_meta.reg_reduce_fn, user_meta.gmem_reduce_fn = node_meta.fn + user_meta.element_compute = node_meta.element_compute + user_meta.round_style = node_meta.round_style + + # Replace all uses + self.dag_ir.remove_edge(input, node) + input_users = self.dag_ir.get_users(input) + for iu in input_users: + weight = self.dag_ir.get_edge_weight(input, iu) + self.dag_ir.add_edge(user, iu, weight) + self.dag_ir.remove_edge(input, iu) + self.dag_ir.add_edge(input, user) + self.dag_ir.remove_node(node) + + # Register the reduction name + self.dag_ir.reduction_names.append(user) diff --git 
a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py new file mode 100644 index 0000000000000000000000000000000000000000..cb90a82c8f637429d3c64b3d881eb30d02c8c804 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py @@ -0,0 +1,59 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Shape and type propagation pass +""" + +from cutlass_cppgen.backend.evt.ir.node import NodeBase +from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase +from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed + + +class PassShapeTypePropagation(EVTPassBase): + """ + Propagate the shape and type of all nodes + """ + dependencies = [PassPreprocessRed] + + def call(self): + # Propagate the node shape and type + for node in self.dag_ir.nodes_topological_order(): + node_meta: NodeBase = self.dag_ir.get_node_meta(node) + input_node_metas = self.dag_ir.get_all_inputs_meta(node) + node_meta.type_propagation(input_node_metas) + node_meta.shape_propagation(input_node_metas) + + for node in reversed(self.dag_ir.nodes_topological_order()): + node_meta: NodeBase = self.dag_ir.get_node_meta(node) + input_node_metas = self.dag_ir.get_all_inputs_meta(node) + node_meta.broadcast_propagation(input_node_metas) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py new file mode 100644 index 0000000000000000000000000000000000000000..8168c59733a5da15eacbbe583c890610655ecff5 --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py @@ -0,0 +1,319 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Compute the shared memory size in bytes +""" + +from math import gcd + +import cutlass_library +from pycute import flatten, shape_div, product + +import cutlass_cppgen +from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR +from cutlass_cppgen.backend.library import DataType, DataTypeSize + + +class GetSmemSize: + """ + Get the size in byte of shared memory used by the kernel + """ + def __init__(self, dag_ir: DAGIR) -> None: + self.dag_ir = dag_ir + self.cc = self.dag_ir.cc + + # + # Sm90 epilogue specific + # + + def sm90_epilogue_tile(self, tile_description): + # Get the epilogue tile size + schedule = tile_description.epilogue_schedule + if schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecialized: + element_d = self.dag_ir.get_node_meta("D").element + nperf = 64 if (DataTypeSize[element_d] == 8 and tile_description.threadblock_shape[1] % 64 == 0) else 32 + epi_tile_m = min(64, tile_description.threadblock_shape[0]) + epi_tile_n = gcd(min(nperf, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1]) + epilogue_tile_mn = (epi_tile_m, epi_tile_n) + elif schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecializedCooperative: + epi_tile_m = min(128, tile_description.threadblock_shape[0]) + epi_tile_n = gcd(min(32, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1]) + epilogue_tile_mn = (epi_tile_m, epi_tile_n) + else: + raise NotImplementedError(f"Unsupported schedule: {schedule}") + + # Get the pipeline stages + stages_d = 2 + epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn)) + if self.dag_ir.has_node("C"): + element_c = self.dag_ir.get_node_meta("C").element + else: + element_c = None + + element_d = self.dag_ir.get_node_meta("D").element + if element_c == element_d: + reuse_smem_c = True + else: + reuse_smem_c = False + stages_c = 
max(epi_tiles, stages_d + 1) if reuse_smem_c else epi_tiles + + # Record the epilogue tile + self.cta_tile_mnk = tuple(tile_description.threadblock_shape) + self.epilogue_tile_mn = epilogue_tile_mn + self.epi_tiles = epi_tiles + self.stages_c = stages_c + self.stages_d = stages_d + self.reuse_smem_c = reuse_smem_c + self.element_c = element_c + self.element_d = element_d + self.is_source_supported = element_c is not None + + def sm90_or_sm100_epilogue_smem_size(self, tile_description): + # Get the Fusion Storage + nodes = self.dag_ir.nodes_topological_order() + self.smem_types = {} + for node in nodes: + meta = self.dag_ir.get_node_meta(node) + if not meta.disabled: + self.smem_types[node] = meta.underlying_impl.get_smem_size( + self.cta_tile_mnk, self.epilogue_tile_mn, + self.stages_c, self.stages_d, self.epi_tiles) + if node == "D": + continue + if isinstance(meta, TopoVisitorNode): + self.get_dag_smem_type(node) + else: + self.get_evt_smem_type(node) + + thread_smem_size = self.smem_types[self.dag_ir.get_all_inputs("D")[0]][0] + # Get the Tensor Storage + tensors = [] + if self.is_source_supported: + smem_C = DataTypeSize[self.element_c] * product(self.epilogue_tile_mn) * self.stages_c // 8 + tensors.append((smem_C, 128)) + else: + tensors.append((0, 1)) + if self.reuse_smem_c: + tensors.append((0, 128)) + else: + smem_D = DataTypeSize[self.element_d] * product(self.epilogue_tile_mn) * self.stages_d // 8 + tensors.append((smem_D, 128)) + tensors.append((thread_smem_size, 128)) + + tensor_smem_size = self.get_struct_size(tensors) + # Get pipeline storage size + # sizeof(uint64_t * stages_c * 2), alignment of uint64_t + # 2 is for FullBarrier and EmptyBarrier + pipeline_smem_size = (8 * self.stages_c * 2, 8) + + # get SharedStorage size + smem_size = self.get_struct_size([tensor_smem_size, pipeline_smem_size]) + return smem_size[0] + + def sm90_epilogue_smem_size(self, tile_description): + """ + Compute the shared memory size of sm90 collective epilogue + """ + 
self.sm90_epilogue_tile(tile_description) + return self.sm90_or_sm100_epilogue_smem_size(tile_description) + + # + # Sm100 epilogue specific + # + + def sm100_epilogue_tile(self, tile_description): + cta_tile = (tile_description.blackwell_threadblock_shape[0], tile_description.blackwell_threadblock_shape[1]) + mma_tile = cta_tile + + if tile_description.is_2sm: + cta_tile = (cta_tile[0] // 2, cta_tile[1]) + + if tile_description.is_2sm and mma_tile[0] == 128: + tmem_warps = (2, 2) + else: + tmem_warps = (4, 1) + + if self.dag_ir.has_node("C"): + element_c = self.dag_ir.get_node_meta("C").element + element_c_size = DataTypeSize[element_c] + else: + element_c = None + element_c_size = 0 + + element_d = self.dag_ir.get_node_meta("D").element + + DisableSource = element_c is None or not self.dag_ir.has_node("C") or self.dag_ir.get_node_meta("C").element == DataType.void + + CtaM = cta_tile[0] + CtaN = cta_tile[1] + WarpM = tmem_warps[0] + WarpN = tmem_warps[1] + MaxBits = max(element_c_size, DataTypeSize[element_d]) + DpFull = 32 + M = min(CtaM, DpFull * WarpM) + + if DisableSource: + # Epilogues w/o residual load are less sensitive to smem allocation + # Target a fixed amount of compute per epilogue iteration + if MaxBits == 4: + # Make epilogue tile larger to reduce the epilogue iterations. + # 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same. 
+ ComputeElts = 8192 + Nperf = ComputeElts // M + else: + ComputeElts = 4096 + Nperf = ComputeElts // M + else: + # Epilogues w/ residual load are more sensitive to smem allocation + # Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize + if MaxBits == 32: + Nperf = 16 if CtaM > 64 and CtaN <= 128 else 32 + elif MaxBits == 16: + Nperf = 32 if CtaN <= 128 else 64 + else: + Nperf = 64 + + def is_m_major(layout): + return flatten(layout.stride[0]) == 1 + + if DisableSource or is_m_major(self.dag_ir.get_node_meta("C").tensor.layout): + N_min_C = 8 * WarpN + elif element_c_size == 6: + N_min_C = 128 * WarpN + else: + N_min_C = (128 // element_c_size) * WarpN + + if is_m_major(self.dag_ir.get_node_meta("D").tensor.layout): + N_min_D = 8 * WarpN + elif DataTypeSize[element_d] == 6: + N_min_D = 128 * WarpN + else: + N_min_D = (128 // DataTypeSize[element_d]) * WarpN + + N = min(CtaN, max(Nperf, N_min_C, N_min_D)) + + tile_m = M + tile_n_size = N // WarpN * WarpN + + epilogue_tile_mn = (tile_m, tile_n_size) + epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn)) + + stages_d = min(epi_tiles, 2) + reuse_smem_c = (element_c_size > 8) + + if reuse_smem_c: + stages_c = max(min(epi_tiles, 4), stages_d + 1) + else: + stages_c = min(epi_tiles, 4) + + # Record the epilogue tile + self.cta_tile_mnk = tuple(tile_description.threadblock_shape) + self.epilogue_tile_mn = epilogue_tile_mn + self.epi_tiles = epi_tiles + self.stages_c = stages_c + self.stages_d = stages_d + self.reuse_smem_c = reuse_smem_c + self.element_c = element_c + self.element_d = element_d + self.is_source_supported = not DisableSource + + def sm100_epilogue_smem_size(self, tile_description): + """ + Compute the shared memory size of sm100 collective epilogue + """ + self.sm100_epilogue_tile(tile_description) + return self.sm90_or_sm100_epilogue_smem_size(tile_description) + + def __call__(self, tile_description): + return getattr(self, 
f"sm{self.cc}_epilogue_smem_size")(tile_description) + + # + # Helper functions + # + + @staticmethod + def get_visitor_size(members: list, ebo: bool): + """ + Get the size of struct in bytes + """ + offset = 0 + max_alignment = 1 + if len(members) > 0: + # Get alignment + for _, alignment in members: + max_alignment = max(max_alignment, alignment) + + for type_size, _ in members: + if type_size != 0: + offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment + if type_size == 0 and not ebo: + offset += 1 + else: + offset += type_size + offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment + return (offset, max_alignment) + else: + # Struct size is at least 1 + return (1, 1) + + def get_struct_size(self, members: list): + """ + Get the size of struct in bytes + """ + return self.get_visitor_size(members, False) + + def get_evt_smem_type(self, node): + # Sort the input nodes by edge weight + input_types = [self.smem_types[child] for child in self.dag_ir.get_all_inputs(node)] + input_types.append(self.smem_types[node]) + if len(input_types) > 1: + ebo = len(input_types) > 4 + self.smem_types[node] = self.get_visitor_size(input_types, ebo) + + def get_dag_smem_type(self, node): + meta = self.dag_ir.get_node_meta(node) + subgraph = meta.subgraph + subgraph_nodes = subgraph.nodes_topological_order() + # Visit the unvisited nodes in subgraph + for n in subgraph_nodes: + m = subgraph.get_node_meta(n) + if m.disabled: + continue + else: + self.smem_types[n] = m.underlying_impl.get_smem_size( + self.cta_tile_mnk, self.epilogue_tile_mn, + self.stages_c, self.stages_d, self.epi_tiles) + input_types = [self.smem_types[child] for child in subgraph_nodes[:-1]] + if len(input_types) > 0: + ebo = len(input_types) > 4 + self.smem_types[node] = self.get_visitor_size(input_types, ebo) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/util.py 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/util.py new file mode 100644 index 0000000000000000000000000000000000000000..4b72e330523ca1e4fb8c5d4526289641e158e72e --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/util.py @@ -0,0 +1,46 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for passes +""" + +# Map from the CC of the kernel to the EVT implementation that the CC targets +cc_map = { + 80: 80, + 86: 80, + 89: 80, + 90: 90, + 100: 100, + 101: 100, + 103: 100, +} diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/frontend.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/frontend.py new file mode 100644 index 0000000000000000000000000000000000000000..a959976b8601b0793c4c7c1709d61c8c838df838 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/frontend.py @@ -0,0 +1,109 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# +from __future__ import annotations + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +import numpy as np + +from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice +from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor + + +class NumpyFrontend: + """ + Frontend node for numpy + """ + + @staticmethod + def argument(np_tensor: "np.ndarray", is_output: "bool") -> cuda.CUdeviceptr: + """Convert the input numpy tensor to CUDA device pointer + + :param np_tensor: input numpy nd array + :param is_output: whether the tensor is output + + :return: CUDA device pointer + """ + # copy the data to device + if is_output: + return device_mem_alloc(np_tensor.size * np_tensor.itemsize) + else: + return todevice(np_tensor) + + +class TorchFrontend: + """ + Frontend node for torch + """ + + @staticmethod + def argument(torch_tensor: "torch.Tensor") -> cuda.CUdeviceptr: + """Convert the input torch tensor to CUDA device pointer + + :param torch_tensor: input torch tensor + :param is_output: whether the tensor is output + + :return: CUDA device pointer + """ + + # check the device of torch_tensor + if not torch_tensor.is_cuda: + torch_tensor = torch_tensor.to("cuda") + + return cuda.CUdeviceptr(torch_tensor.data_ptr()) + + +class CupyFrontend: + """ + Frontend node for cupy + """ + + @staticmethod + def argument(cupy_ndarray: "cp.ndarray"): + return cuda.CUdeviceptr(int(cupy_ndarray.data.ptr)) + + +class TensorFrontend: + """ + Universal Frontend for client-provide tensors + """ + + @staticmethod + def argument(tensor, is_output=False): + if is_numpy_tensor(tensor): + return NumpyFrontend.argument(tensor, is_output) + elif is_torch_tensor(tensor): + return TorchFrontend.argument(tensor) + elif is_cupy_tensor(tensor): + return CupyFrontend.argument(tensor) + else: + raise 
NotImplementedError("Unknown Tensor Type") diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/gemm_operation.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/gemm_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..5e2a3a30a097eb45c691554daf70f8db12e5bc48 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/gemm_operation.py @@ -0,0 +1,2145 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# +from __future__ import annotations + +import copy +import ctypes +import enum + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +cudart = lazy_import("cuda.cudart") +from cutlass_library import SubstituteTemplate +import numpy as np + +from cutlass_library import ( + ComplexTransformTag, + DataType, + DataTypeNames, + DataTypeSize, + DataTypeTag, + EpilogueScheduleSuffixes, + EpilogueScheduleTag, + EpilogueScheduleType, + GemmKind, + GemmKindNames, + GemmUniversalMode, + KernelScheduleSuffixes, + KernelScheduleTag, + KernelScheduleType, + LayoutTag, + LayoutType, + MathOperation, + MathOperationTag, + OpcodeClass, + OpcodeClassNames, + OpcodeClassTag, + OperationKind, + ShortComplexLayoutNames, + ShortDataTypeNames, + ShortLayoutTypeNames, + SwizzlingFunctor, + SwizzlingFunctorTag, + TileSchedulerSuffixes, + TileSchedulerTag, + TileSchedulerType, + get_complex_from_real +) +from cutlass_cppgen.backend.arguments import ArgumentBase +from cutlass_cppgen.backend.c_types import ( + GemmCoord_, + GemmCoordBatched_, + GenericMainloopArguments3x_, + StrideBatched_, + dim3_, + get_gemm_arguments, + get_gemm_arguments_3x, + get_gemm_arguments_streamk, + get_gemm_grouped_arguments, + get_mainloop_arguments_3x, + get_tile_scheduler_arguments_3x, +) +from cutlass_cppgen.backend.library import ( + ApiVersion, + 
EmissionType, + SchedulerMode, + SchedulerModeTag, + TensorDescription, + TileDescription, + api_version, +) +from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice +from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration +from cutlass_cppgen.backend.type_hint import GemmOperation, Tensor +from cutlass_cppgen.backend.utils.device import device_sm_count +from cutlass_cppgen.shape import GemmCoord, MatrixCoord + + +################################################################################ +# +# Data structure modeling a GEMM operation +# +################################################################################ + + +def leading_dimension(layout: LayoutType, shape: MatrixCoord) -> int: + """ + Returns the leading dimenson of a tensor with layout ``layout`` and shape ``shape``. + + :param layout: layout of the tensor + :type layout: cutlass_cppgen.shape.LayoutType + :param shape: shape of the tensor + :type shape: cutlass_cppgen.shape.MatrixCoord + + :return: leading dimension of the tensor + :rtype: int + """ + if layout == LayoutType.RowMajor: + return shape.column + elif layout == LayoutType.ColumnMajor: + return shape.row + + +def transpose_layout(layout: LayoutType) -> LayoutType: + if layout == LayoutType.ColumnMajor: + return LayoutType.RowMajor + elif layout == LayoutType.RowMajor: + return LayoutType.ColumnMajor + else: + raise ValueError(f"Unsupported Layout {layout}") + + +class GemmArguments2x(ArgumentBase): + """ + Argument wrapper for GEMM in CUTLASS 2. 
It encodes problem information and + user-provide tensors into the kernel's argument + + :param operation: the GEMM operation to take the argument + :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` | + :class:`cutlass_cppgen.backend.GemmOperationGrouped` + + :param problem_size: GEMM problem size gemm(M, N, K) + :type operation: :class:`cutlass_cppgen.shape.GemmCoord` + + :param A: tensor A + :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param B: tensor B + :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param C: tensor C + :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param D: tensor D + :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param gemm_mode: GEMM mode + :type gemm_mode: :class:`cutlass_library.GemmUniversalMode` + + :param output_op: output operator, optional + :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` + + :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) + :type stream: :class:`cuda.cuda.CUstream` + """ + + def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs): + self.operation = operation + + self.layout_A = operation.A.layout + self.layout_B = operation.B.layout + self.layout_C = operation.C.layout + + self.element_A = operation.A.element + self.element_B = operation.B.element + self.element_C = operation.C.element + + if operation.C.layout in [LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32]: + raise Exception("Interleaved layout not currently supported") + + if hasattr(self.operation.epilogue_functor, "visitor") and operation.arch not in [90, 100, 101, 103]: + super().__init__(A, B, None, None, **kwargs) + else: + super().__init__(A, B, C, D, **kwargs) + + if operation.switched: + self.problem_size = GemmCoord(problem_size.n, problem_size.m, problem_size.k) + self.ptr_A, self.ptr_B = 
self.ptr_B, self.ptr_A + else: + self.problem_size = problem_size + # If the number of elements in C = problem_size.n, C is treated as the bias + if hasattr(self, "tensor_c_numel"): + if self.tensor_c_numel == self.problem_size.n and self.problem_size.m != 1: + self.bias = True + + self.lda = leading_dimension(self.layout_A, self.problem_size.mk) + self.ldb = leading_dimension(self.layout_B, self.problem_size.kn) + self.ldc = leading_dimension(self.layout_C, self.problem_size.mn) + self.ldd = self.ldc + + if self.bias: + self.ldc = 0 + + if "output_op" in kwargs.keys() and gemm_mode != GemmUniversalMode.GemmSplitKParallel: + self.output_op = kwargs["output_op"] + else: + if self.operation.epilogue_functor.element_epilogue in [DataType.s8, DataType.s32, DataType.u8, DataType.u32]: + dtype = int + else: + dtype = float + self.output_op = self.operation.epilogue_type(dtype(1.0), dtype(0.0)) + + self.gemm_mode = gemm_mode + if gemm_mode in [GemmUniversalMode.Gemm, GemmUniversalMode.GemmSplitKParallel]: + if "split_k_slices" in kwargs.keys(): + self.batch_count = kwargs["split_k_slices"] + else: + self.batch_count = 1 + self.split_k_slices = self.batch_count + + if gemm_mode in [GemmUniversalMode.Batched, GemmUniversalMode.Array]: + if "batch" in kwargs.keys(): + self.batch_count = kwargs["batch"] + else: + self.batch_count = 1 + + if "batch_strides" in kwargs: + self.batched_stride_A = kwargs["batch_strides"]["A"] + self.batched_stride_B = kwargs["batch_strides"]["B"] + self.batched_stride_C = kwargs["batch_strides"]["C"] + self.batched_stride_D = kwargs["batch_strides"]["D"] + else: + self.batched_stride_A = self.problem_size.m * self.problem_size.k + self.batched_stride_B = self.problem_size.n * self.problem_size.k + self.batched_stride_C = self.problem_size.m * self.problem_size.n + self.batched_stride_D = self.problem_size.m * self.problem_size.n + + if self.bias: + self.batched_stride_C = self.problem_size.n + + if gemm_mode == GemmUniversalMode.Array: + 
self.ptr_A_array = [] + self.ptr_B_array = [] + self.ptr_C_array = [] + self.ptr_D_array = [] + + ptr_A_addr = int(self.ptr_A) + ptr_B_addr = int(self.ptr_B) + ptr_C_addr = int(self.ptr_C) + ptr_D_addr = int(self.ptr_D) + + stride_A = self.batched_stride_A * DataTypeSize[self.element_A] // 8 + stride_B = self.batched_stride_B * DataTypeSize[self.element_B] // 8 + stride_C = self.batched_stride_C * DataTypeSize[self.element_C] // 8 + stride_D = self.batched_stride_D * DataTypeSize[self.element_C] // 8 + for _ in range(self.batch_count): + self.ptr_A_array.append(ptr_A_addr) + self.ptr_B_array.append(ptr_B_addr) + self.ptr_C_array.append(ptr_C_addr) + self.ptr_D_array.append(ptr_D_addr) + + ptr_A_addr += stride_A + ptr_B_addr += stride_B + ptr_C_addr += stride_C + ptr_D_addr += stride_D + + self.ptr_A_array_buffer = todevice(self.ptr_A_array, dtype=np.int64) + self.ptr_B_array_buffer = todevice(self.ptr_B_array, dtype=np.int64) + self.ptr_C_array_buffer = todevice(self.ptr_C_array, dtype=np.int64) + self.ptr_D_array_buffer = todevice(self.ptr_D_array, dtype=np.int64) + + if isinstance(self.operation, GemmOperationUniversal): + self.initialize() + + def get_arguments(self): + problem_size_ = self.problem_size.ctype + grid_tiled_shape_ = GemmCoord( + self.grid_tiled_shape.x, + self.grid_tiled_shape.y, + self.grid_tiled_shape.z ).ctype + + if self.gemm_mode == GemmUniversalMode.Array: + arguments = self.operation.argument_type( + # Arguments from UniversalArgumentsBase + self.gemm_mode, + problem_size_, + self.batch_count, + 0, + # Remaining arguments + self.output_op, + int(self.ptr_A_array_buffer.ptr), + int(self.ptr_B_array_buffer.ptr), + int(self.ptr_C_array_buffer.ptr), + int(self.ptr_D_array_buffer.ptr), + 0, 0, 0, + self.lda, self.ldb, self.ldc, self.ldd, + self.lda, self.ldb, self.ldc, self.ldd, + 0, 0, 0 + ) + else: + arguments = self.operation.argument_type( + # Arguments from UniversalArgumentsBase + self.gemm_mode, problem_size_, self.batch_count, 
self.batched_stride_D, + # Remaining arguments + self.output_op, + int(self.ptr_A), + int(self.ptr_B), + int(self.ptr_C), + int(self.ptr_D), + self.batched_stride_A, + self.batched_stride_B, + self.batched_stride_C, + self.lda, self.ldb, self.ldc, self.ldd, + self.lda, self.ldb, self.ldc, self.ldd, + 0, 0, 0 + ) + + self.arguments = arguments, grid_tiled_shape_, self.gemm_k_size + + def initialize(self): + launch_config = self.operation.rt_module.plan(self) + + # Get the host and device workspace + device_workspace_size = self.operation.rt_module.get_device_workspace_size(self) + + if device_workspace_size > 0: + self.workspace_buffer = device_mem_alloc(device_workspace_size) + workspace_ptr = self.workspace_buffer.ptr + err, = cuda.cuMemsetD32( + workspace_ptr, 0, device_workspace_size // 4) + else: + workspace_ptr = None + + device_workspace = 0 + if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel: + # In GEMM splik-K parallel, the D pointer is redirected to the workspace + self.ptr_D = cuda.CUdeviceptr(workspace_ptr) + elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm: + device_workspace = workspace_ptr + + self.get_arguments() + + arguments, grid_tiled_shape, gemm_k_size = self.arguments + res_arg = self.operation.rt_module.get_args( + ctypes.byref(arguments), ctypes.c_void_p(int(device_workspace))) + host_workspace = bytearray(res_arg.contents) + + device_workspace = None + + self.host_workspace = host_workspace + self.device_workspace = device_workspace + self.launch_config = launch_config + + def sync(self, stream_sync=True): + super().sync(stream_sync) + if hasattr(self.output_op, "sync"): + self.output_op.sync() + + +class GemmArguments2xStreamK(GemmArguments2x): + """ + Argument wrapper for stream-K GEMMs in CUTLASS 2. 
It encodes problem information and + user-provide tensors into the kernel's argument + + :param operation: the GEMM operation to take the argument + :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` | + :class:`cutlass_cppgen.backend.GemmOperationGrouped` + + :param problem_size: GEMM problem size gemm(M, N, K) + :type operation: :class:`cutlass_cppgen.shape.GemmCoord` + + :param A: tensor A + :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param B: tensor B + :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param C: tensor C + :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param D: tensor D + :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param gemm_mode: GEMM mode + :type gemm_mode: :class:`cutlass_library.GemmUniversalMode` + + :param output_op: output operator, optional + :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` + """ + + def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs): + if gemm_mode not in [GemmUniversalMode.Gemm, GemmUniversalMode.Batched]: + raise Exception(f"Unsupported GEMM mode {gemm_mode}.") + + super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs) + + def get_arguments(self): + batch_stride_A = self.problem_size.m * self.problem_size.k + batch_stride_B = self.problem_size.k * self.problem_size.n + batch_stride_C = self.problem_size.m * self.problem_size.n + batch_stride_D = self.problem_size.m * self.problem_size.n + + arguments = self.operation.argument_type( + self.gemm_mode, + GemmCoord_(self.problem_size.m, self.problem_size.n, self.problem_size.k), + self.batch_count, + self.output_op, + int(self.ptr_A), + int(self.ptr_B), + int(self.ptr_C), + int(self.ptr_D), + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D, + self.lda, self.ldb, self.ldc, self.ldd, # strides + self.lda, 
self.ldb, self.ldc, self.ldd, + -1, # avail_sms + ) + return arguments + + def initialize(self): + # Get the host and device workspace + device_workspace_size = self.operation.rt_module.get_device_workspace_size( + self, + device_sm_count(), + self.operation.rt_module.occupancy + ) + + if device_workspace_size > 0: + self.workspace_buffer = device_mem_alloc(device_workspace_size) + workspace_ptr = self.workspace_buffer.ptr + err, = cuda.cuMemsetD32( + workspace_ptr, 0, device_workspace_size // 4) + else: + workspace_ptr = None + + device_workspace = 0 + if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel: + # In GEMM splik-K parallel, the D pointer is redirected to the workspace + self.ptr_D = cuda.CUdeviceptr(workspace_ptr) + elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm: + device_workspace = workspace_ptr + + arguments = self.get_arguments() + + res_arg = self.operation.rt_module.get_args( + ctypes.byref(arguments), + ctypes.c_void_p(int(device_workspace)), + device_sm_count(), + self.operation.rt_module.occupancy + ) + host_workspace = bytearray(res_arg.contents) + + grid = self.operation.rt_module.get_grid_shape( + ctypes.byref(arguments), + device_sm_count(), + self.operation.rt_module.occupancy + ) + + device_workspace = None + + self.host_workspace = host_workspace + self.device_workspace = device_workspace + self.launch_config = LaunchConfiguration( + [grid.m, grid.n, grid.k], + [self.operation.rt_module.threads, 1, 1], + self.operation.rt_module.shared_memory_capacity + ) + + +class GemmArguments3x(GemmArguments2x): + """ + Argument wrapper for GEMM in CUTLASS 3. 
It encodes problem information and + user-provide tensors into the kernel's argument + + :param operation: the GEMM operation to take the argument + :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` | + :class:`cutlass_cppgen.backend.GemmOperationGrouped` + + :param problem_size: GEMM problem size gemm(M, N, K) + :type operation: :class:`cutlass_cppgen.shape.GemmCoord` + + :param A: tensor A + :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param B: tensor B + :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param C: tensor C + :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param D: tensor D + :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param gemm_mode: GEMM mode + :type gemm_mode: GemmUniversalMode + + :param output_op: output operator, optional + :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` + """ + + def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs): + if gemm_mode not in [GemmUniversalMode.Gemm, GemmUniversalMode.Batched]: + raise Exception(f"Unsupported GEMM mode {gemm_mode}.") + + super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs) + + def get_arguments(self): + mainloop_args = get_mainloop_arguments_3x( + self.operation.tile_description.kernel_schedule, + self.operation.A.element, + self.operation.B.element, + self.operation.A.alignment, + self.operation.B.alignment + ) + scheduler_args = get_tile_scheduler_arguments_3x(self.operation.tile_description.tile_scheduler) + uses_default_epilogue = self.operation.rt_module.uses_default_epilogue() + argument_type, epilogue_args, epilogue_type, hw_info = get_gemm_arguments_3x( + mainloop_args, self.operation.epilogue_functor, scheduler_args, uses_default_epilogue) + + problem_size_ = GemmCoordBatched_(self.problem_size, self.batch_count) + + if self.batch_count > 1: + 
bsA = self.batched_stride_A + bsB = self.batched_stride_B + bsC = self.batched_stride_C + bsD = self.batched_stride_D + else: + bsA = 0 + bsB = 0 + bsC = 0 + bsD = 0 + stride_A = StrideBatched_(self.lda, bsA) + stride_B = StrideBatched_(self.ldb, bsB) + stride_C = StrideBatched_(self.ldc, bsC) + stride_D = StrideBatched_(self.ldd, bsD) + + # Superset of potential mainloop arguments + generic_args = GenericMainloopArguments3x_( + int(self.ptr_A), + stride_A, + int(self.ptr_B), + stride_B, + 4 # mma_promotion_interval + ) + + # Set of mainloop arguments needed for this kernel + mainloop = mainloop_args.from_generic_mainloop_args(generic_args) + + if not uses_default_epilogue and hasattr(self.output_op, "to_evt_params"): + self.output_op = self.output_op.to_evt_params() + + epilogue = epilogue_args( + self.output_op, + int(self.ptr_C), + stride_C, + int(self.ptr_D), + stride_D, + ) + + # Set hardware info + hw_info_ = hw_info( + 0, device_sm_count(), 0, + dim3_(0,0,0), + dim3_(0,0,0), + ) + + self.arguments = argument_type( + int(self.gemm_mode), + problem_size_, + mainloop, + epilogue, + hw_info_, + scheduler_args + ) + return self.arguments + + def initialize(self): + # Get the host and evice workspace + device_workspace_size = self.operation.rt_module.get_device_workspace_size(self) + + if device_workspace_size > 0: + self.workspace_buffer = device_mem_alloc(device_workspace_size) + workspace_ptr = self.workspace_buffer.ptr + err, = cuda.cuMemsetD32( + workspace_ptr, 0, device_workspace_size // 4) + else: + workspace_ptr = None + + device_workspace = 0 + if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel: + # In GEMM splik-K parallel, the D pointer is redirected to the workspace + self.ptr_D = cuda.CUdeviceptr(workspace_ptr) + elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm: + device_workspace = workspace_ptr + + self.get_arguments() + res_arg = self.operation.rt_module.get_args( + 
ctypes.byref(self.arguments), + ctypes.c_void_p(int(device_workspace)), + ) + host_workspace = bytearray(res_arg.contents) + + grid = self.operation.rt_module.get_grid_shape( + ctypes.byref(self.arguments), + ctypes.c_void_p(int(device_workspace)), + ) + block = self.operation.rt_module.get_block_shape() + + device_workspace = None + + self.host_workspace = host_workspace + self.device_workspace = device_workspace + self.launch_config = LaunchConfiguration( + [grid.x, grid.y, grid.z], + [block.x, block.y, block.z], + self.operation.rt_module.shared_memory_capacity, + ) + + +def GemmArguments(operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs): + """ + Argument wrapper for GEMM in CUTLASS 2 or 3. It returns either 2x arguments + or 3x arguments depending on the `arch` field specified in `operation`. + + :param operation: the GEMM operation to take the argument + :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` | + :class:`cutlass_cppgen.backend.GemmOperationGrouped` + + :param problem_size: GEMM problem size gemm(M, N, K) + :type operation: :class:`cutlass_cppgen.shape.GemmCoord` + + :param A: tensor A + :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param B: tensor B + :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param C: tensor C + :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param D: tensor D + :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray + + :param gemm_mode: GEMM mode + :type gemm_mode: :class:`cutlass_library.GemmUniversalMode` + + :param output_op: output operator, optional + :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` + """ + if operation.swizzling_functor == SwizzlingFunctor.StreamK: + if operation.api == ApiVersion.v3x: + raise Exception("Stream K is currently only supported in CUTLASS 2.x") + ArgClass = GemmArguments2xStreamK + else: + ArgClass = 
GemmArguments3x if operation.api == ApiVersion.v3x else GemmArguments2x + return ArgClass(operation, problem_size, A, B, C, D, gemm_mode, **kwargs) + + +class GemmGroupedArguments: + """ + Argument wrapper for GEMM Grouped. It encodes problem information and + user-provide tensors into the kernel's argument + + :param operation: the GEMM Grouped operation to take the argument + :type operation: :class:`cutlass_cppgen.backend.GemmOperationGrouped` + + :param problem_size: list of GEMM problem size gemm(M, N, K) + :type operation: list[:class:`cutlass_cppgen.shape.GemmCoord`] + + :param A: list of tensor A + :type A: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] + + :param B: list of tensor B + :type B: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] + + :param C: list of tensor C + :type C: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] + + :param D: list of tensor D + :type D: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray] + + :param output_op: output operator, optional + :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments` + + :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) + :type stream: :class:`cuda.cuda.CUstream` + """ + + def __init__(self, operation, problem_sizes, A, B, C, D, **kwargs): + # Get number of problems in the group + self.problem_count = len(problem_sizes) + + # Check the input arguments + assert len(A) == self.problem_count + assert len(B) == self.problem_count + assert len(C) == self.problem_count + assert len(D) == self.problem_count + + problem_size_host = [] + self.ptr_A_host = [] + self.ptr_B_host = [] + self.ptr_C_host = [] + self.ptr_D_host = [] + + lda_host = [] + ldb_host = [] + ldc_host = [] + ldd_host = [] + + self.partitions = 1 + + self.operation = operation + + # Get the threadblock + threadblock_shape = operation.tile_description.threadblock_shape + self.threadblock_shape = GemmCoord( + 
threadblock_shape[0], + threadblock_shape[1], + threadblock_shape[2], + ) + self.threadblock_swizzle = operation.swizzling_functor + + self.total_tiles = 0 + + self.gemm_arguments = [] + + self.stream = kwargs.get("stream", cuda.CUstream(0)) + + # Process the input arguments + for idx, problem_size in enumerate(problem_sizes): + M, N, K = problem_size.m, problem_size.n, problem_size.k + temp_argument = GemmArguments2x( + operation=operation, + problem_size=GemmCoord(M, N, K), + A=A[idx], B=B[idx], C=C[idx], D=D[idx]) + self.gemm_arguments.append(temp_argument) + + problem_size_host.append( + [temp_argument.problem_size.m, + temp_argument.problem_size.n, + temp_argument.problem_size.k] + ) + + self.ptr_A_host.append(int(temp_argument.ptr_A)) + lda_host.append(temp_argument.lda) + + self.ptr_B_host.append(int(temp_argument.ptr_B)) + ldb_host.append(temp_argument.ldb) + + self.ptr_C_host.append(int(temp_argument.ptr_C)) + ldc_host.append(temp_argument.ldc) + + self.ptr_D_host.append(int(temp_argument.ptr_D)) + ldd_host.append(temp_argument.ldd) + + # Get number of tiles + grid = self.operation.rt_module.get_grid_shape( + self.operation.rt_module.get_tiled_shape( + temp_argument.problem_size.ctype, + self.threadblock_shape.ctype, + temp_argument.batch_count + ) + ) + self.total_tiles += grid.x * grid.y * grid.z + + self.problem_size_buffer = todevice(problem_size_host, np.int32) + self.ptr_A_buffer = todevice(self.ptr_A_host, np.int64) + self.ptr_B_buffer = todevice(self.ptr_B_host, np.int64) + self.ptr_C_buffer = todevice(self.ptr_C_host, np.int64) + self.ptr_D_buffer = todevice(self.ptr_D_host, np.int64) + + self.lda_buffer = todevice(lda_host, np.int64) + self.ldb_buffer = todevice(ldb_host, np.int64) + self.ldc_buffer = todevice(ldc_host, np.int64) + self.ldd_buffer = todevice(ldd_host, np.int64) + + if "output_op" in kwargs.keys(): + self.alpha = kwargs["output_op"].alpha + self.beta = kwargs["output_op"].beta + else: + self.alpha = 1.0 + self.beta = 0.0 + + if 
"output_op" in kwargs.keys(): + self.output_op = kwargs["output_op"] + else: + self.output_op = self.operation.epilogue_type(1.0, 0.0) + + # Get host problem size + self.host_problem_size_ptr = np.array(problem_size_host, dtype=np.int32).__array_interface__["data"][0] + + self.arguments = self.get_arguments() + + self.initialize() + + def get_arguments(self): + return self.operation.argument_type( + self.problem_size_buffer.ptr, + self.problem_count, + self.total_tiles, + self.output_op, + self.ptr_A_buffer.ptr, + self.ptr_B_buffer.ptr, + self.ptr_C_buffer.ptr, + self.ptr_D_buffer.ptr, + self.lda_buffer.ptr, + self.ldb_buffer.ptr, + self.ldc_buffer.ptr, + self.ldd_buffer.ptr, + ctypes.c_void_p(int(self.host_problem_size_ptr)), + ) + + def initialize(self): + # Get launch configuration + launch_config = self.operation.rt_module.plan(self) + + # Get the host and evice workspace + device_workspace_size = self.operation.rt_module.get_device_workspace_size(self) + + if device_workspace_size > 0: + self.workspace_buffer = device_mem_alloc(device_workspace_size) + workspace_ptr = self.workspace_buffer.ptr + err, = cuda.cuMemsetD32( + workspace_ptr, 0, device_workspace_size // 4) + else: + workspace_ptr = None + + if self.operation.precompute_mode == SchedulerMode.Host: + device_workspace_ptr = self.operation.rt_module.host_precompute( + self, self.operation.rt_module.get_workspace_size(self),) + else: + device_workspace_ptr = 0 + + result = self.operation.rt_module.get_args( + ctypes.byref(self.arguments), + self.total_tiles, + ctypes.c_void_p(int(device_workspace_ptr)), + ) + host_workspace = bytearray(result.contents) + + device_workspace = None + + self.host_workspace = host_workspace + self.device_workspace = device_workspace + self.launch_config = launch_config + + def sync(self): + err, = cudart.cudaDeviceSynchronize() + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + for arg in self.gemm_arguments: + 
arg.sync(stream_sync=False) + + +################################################################################ +# Base class for GEMM runtime module +################################################################################ + + +class GemmRTbase(ExecutableOperation): + """ + GemmRT manages the CUTLASS runtime components + """ + + KernelTemplate = r""" +extern "C" +__global__ void +${operation_name}(${operation_name}${operation_suffix}::Params params) { + + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + + // Declare pointer to dynamic shared memory. + ${operation_name}${operation_suffix}::SharedStorage *shared_storage = + reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase); + + ${operation_name}${operation_suffix}::invoke(params, *shared_storage); +} + """ + + def __init__(self, operation: "GemmOperation"): + super().__init__(operation) + + self.operation = operation + threadblock_shape = operation.tile_description.threadblock_shape + self.threadblock_shape = GemmCoord( + threadblock_shape[0], threadblock_shape[1], threadblock_shape[2]) + self.threadblock_swizzle = operation.swizzling_functor + + # Threads per threadblock + self.threads = operation.tile_description.num_threads + + def emit(self): + return self.emitter.emit(self.operation) + + def can_implement(self, configuration, arguments): + raise NotImplementedError() + + def get_host_workspace_size(self, arguments): + raise NotImplementedError() + + def get_device_workspace_size(self, arguments): + return 0 + + def initialize(self): + err, = cuda.cuFuncSetAttribute( + self.kernel, + attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + value=self.shared_memory_capacity) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError( + f"CUDA error on call to cuFuncSetAttribute: {cuda.cuGetErrorString(err)[1]}" + ) + + +################################################################################ 
+# Runtime module for GEMM Universal +################################################################################ + + +class GemmRTUniversal(GemmRTbase): + """ + GemmRTUniversal manages the CUTLASS runtime components + """ + + HostTemplate = r""" +extern "C" { + // Get the size of params in bytes + int ${operation_name}_get_param_size(){ + return sizeof(${operation_name}${operation_suffix}::Params); + } + + // Get the size of dynamic shared memory in bytes + int ${operation_name}_shared_memory_size() { + return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); + } + + // Get the params as byte array + char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int* workspace){ + ${operation_name}_base::Params* params; + params = new ${operation_name}_base::Params(*argument, + -1, // SM count. Only used for stream-K + -1 // Occupancy. Only used for stream-K + ); + + // Semaphore holds the pointer to the workspace in the Params struct + params->semaphore = workspace; + + char *bytes = ((char*)(params)); + char *output = new char[sizeof(${operation_name}_base::Params)]; + for (unsigned int i = 0; i < sizeof(${operation_name}_base::Params); i ++) + output[i] = bytes[i]; + + return output; + } + + cutlass::gemm::GemmCoord ${operation_name}_get_tiled_shape( + cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord tile_size, int split_k_slices) { + return ${operation_name}_base::ThreadblockSwizzle::get_tiled_shape( + problem_size, tile_size, split_k_slices); + } + + dim3 ${operation_name}_get_grid_shape(cutlass::gemm::GemmCoord tiled_shape) { + return ${operation_name}_base::ThreadblockSwizzle::get_grid_shape(tiled_shape); + } +} + """ + + def __init__(self, operation): + super(GemmRTUniversal, self).__init__(operation) + self.extra_funcs = { + "get_tiled_shape": GemmCoord_, + "get_grid_shape": dim3_, + } + self.emitter = EmitGemmUniversalInstance( + "_type", operation.direct_store) + + self.argument_type, 
self.epilogue_type = get_gemm_arguments(operation.epilogue_functor) + self.argtype = [ + ctypes.POINTER(self.argument_type), + ctypes.POINTER(GemmCoord_), ctypes.c_int, ctypes.c_void_p + ] + + def plan(self, arguments): + grid = self.get_tiled_shape( + arguments.problem_size.ctype, + self.threadblock_shape.ctype, + arguments.batch_count + ) + + gemm_k_size = arguments.problem_size.k + if arguments.gemm_mode in [GemmUniversalMode.Gemm, GemmUniversalMode.GemmSplitKParallel]: + alignk = max(max(128 // DataTypeSize[self.operation.A.element], + 128 // DataTypeSize[self.operation.B.element]), 1) + + gemm_k_size = (((arguments.problem_size.k + arguments.batch_count - 1) // + arguments.batch_count + alignk - 1) // alignk) * alignk + + if gemm_k_size: + grid_z = (arguments.problem_size.k + gemm_k_size - 1) // gemm_k_size + grid = GemmCoord(grid.m, grid.n, grid_z).ctype + + arguments.grid_tiled_shape = dim3_(grid.m, grid.n, grid.k) + grid = self.get_grid_shape(grid) + arguments.gemm_k_size = gemm_k_size + return LaunchConfiguration( + [grid.x, grid.y, grid.z], + [self.threads, 1, 1], + self.shared_memory_capacity) + + def get_device_workspace_size(self, arguments: GemmArguments): + workspace_bytes = 0 + if arguments.gemm_mode == GemmUniversalMode.GemmSplitKParallel: + workspace_bytes = (DataTypeSize[arguments.operation.C.element] + * arguments.batched_stride_D * arguments.grid_tiled_shape.z // 8) + elif (arguments.gemm_mode == GemmUniversalMode.Gemm and + arguments.split_k_slices > 1): + workspace_bytes = 4 * arguments.grid_tiled_shape.x * arguments.grid_tiled_shape.y + + return workspace_bytes + + +class GemmRTUniversalStreamK(GemmRTUniversal): + """ + Manages the CUTLASS runtime components for 2.x stream K kernels + """ + + HostTemplate = r""" +extern "C" { + // Get the size of params in bytes + int ${operation_name}_get_param_size(){ + return sizeof(${operation_name}${operation_suffix}::Params); + } + + // Get the size of dynamic shared memory in bytes + int 
${operation_name}_shared_memory_size() { + return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); + } + + using GemmType = ${operation_name}_base; + + // Get the params as byte array + char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace, + int sm_count, int occupancy) { + GemmType::Params* params; + params = new GemmType::Params(*argument, sm_count, occupancy); + + params->init_workspace(workspace); + + char *bytes = ((char*)(params)); + char *output = new char[sizeof(GemmType::Params)]; + for (unsigned int i = 0; i < sizeof(GemmType::Params); i ++) + output[i] = bytes[i]; + + return output; + } + + dim3 ${operation_name}_get_grid_shape(GemmType::Arguments* args, int device_sms, int sm_occupancy) { + typename GemmType::Params params(*args, device_sms, sm_occupancy); + return params.get_grid_dims(); + } + + uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* args, int device_sms, int sm_occupancy) { + typename GemmType::Params params(*args, device_sms, sm_occupancy); + return params.get_workspace_size(); + } +} + """ + + def __init__(self, operation: "GemmOperation"): + super(GemmRTUniversalStreamK, self).__init__(operation) + self.extra_funcs = { + "get_grid_shape": GemmCoord_, + "get_kernel_workspace_size": ctypes.c_uint64, + } + self._occupancy = None + self.argument_type, self.epilogue_type = get_gemm_arguments_streamk(operation.epilogue_functor) + + @property + def occupancy(self): + if self._occupancy is None: + err, self._occupancy = cuda.cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + self.kernel, self.threads, self.shared_memory_capacity, + cuda.CUoccupancy_flags.CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE) + + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError( + "CUDA error on call to cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags: " + f"{cuda.cuGetErrorString(err)[1]}") + return self._occupancy + + def get_device_workspace_size(self, arguments: GemmArguments2xStreamK, 
device_sms: int, sm_occupancy: int): + return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()), device_sms, sm_occupancy) + + +################################################################################ +# Runtime module for GEMM Universal within CUTLASS 3 +################################################################################ + + +class GemmRTUniversal3x(GemmRTUniversal): + """ + Manages the CUTLASS runtime components for 3.x kernels + """ + + KernelTemplate = r""" + +using Operator = ${operation_name}${operation_suffix}; +extern "C" +__global__ __launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor) +void ${operation_name}(__grid_constant__ typename Operator::Params const params) { + // Dynamic shared memory base pointer + extern __shared__ char smem[]; + + // Declare pointer to dynamic shared memory. + Operator op; + op(params, smem); +} + """ + HostTemplate = r""" +extern "C" { + // Get the size of params in bytes + int ${operation_name}_get_param_size(){ + return sizeof(${operation_name}${operation_suffix}::Params); + } + + // Get the size of dynamic shared memory in bytes + int ${operation_name}_shared_memory_size() { + return ${operation_name}${operation_suffix}::SharedStorageSize; + } + + using GemmType = ${operation_name}_base; + + bool ${operation_name}_uses_default_epilogue() { + return std::is_same_v; + } + + // Get the workspace size + uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* argument) { + return GemmType::get_workspace_size(*argument); + } + + // Get the params as byte array + char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace){ + GemmType::Params params = GemmType::to_underlying_arguments(*argument, workspace); + char *bytes = ((char*)(¶ms)); + char *output = new char[sizeof(GemmType::Params)]; + for (unsigned int i = 0; i < sizeof(GemmType::Params); i ++) + output[i] = bytes[i]; + + return output; + } + + // Get the total 
number of blocks for a persistent kernel + uint64_t ${operation_name}_get_persistent_tiled_blk_shape_mnl(GemmType::ProblemShape problem) { + auto problem_shape_MNKL = append<4>(problem, Int<1>{}); + auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = + cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::get_tiled_cta_shape_mnl( + problem_shape_MNKL, GemmType::TileShape{}, GemmType::DispatchPolicy::ClusterShape{}); + return problem_blocks_m * problem_blocks_n * problem_blocks_l; + } + + // Get the grid shape + dim3 ${operation_name}_get_grid_shape(GemmType::Arguments* args, int* workspace) { + auto tmp_params = GemmType::to_underlying_arguments(*args, workspace); + return GemmType::get_grid_shape(tmp_params); + } + + // Get the block shape + dim3 ${operation_name}_get_block_shape() { + return GemmType::get_block_shape(); + } +} + """ + + def __init__(self, operation): + super(GemmRTUniversal3x, self).__init__(operation) + self.extra_funcs = { + "get_grid_shape": dim3_, + "get_block_shape": dim3_, + "get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64, + "get_kernel_workspace_size": ctypes.c_uint64, + "uses_default_epilogue": ctypes.c_bool, + } + self.emitter = EmitGemmUniversalInstance3x("_type") + + def get_device_workspace_size(self, arguments: GemmArguments3x): + return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments())) + + +class EmitGemmUniversalInstance3x: + """Responsible for emitting a CUTLASS 3 template definition""" + + def __init__(self, operation_suffix=""): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cute/tensor.hpp", + "cute/atom/mma_atom.hpp", + "cutlass/numeric_types.h", + "cutlass/gemm/collective/collective_builder.hpp", + "cutlass/gemm/kernel/sm90_tile_scheduler.hpp", + "cutlass/gemm/kernel/gemm_universal.hpp", + "cutlass/epilogue/collective/collective_builder.hpp", + "cutlass/epilogue/collective/default_epilogue.hpp", + 
"cutlass/epilogue/thread/linear_combination.h" + ] + self.gemm_template_kernel = """ +using namespace cute; + +using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ${arch}, ${opcode_class}, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + ${element_accumulator}, ${element_epilogue}, + ${element_c}, ${layout_c}, ${align_c}, + ${element_d}, ${layout_d}, ${align_d}, + ${epilogue_schedule} + >::CollectiveOp; + +using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ${arch}, ${opcode_class}, + ${element_a}, ${layout_a}, ${align_a}, + ${element_b}, ${layout_b}, ${align_b}, + ${element_accumulator}, + cute::Shape, + cute::Shape, + ${stage_count_type}, + ${kernel_schedule} + >::CollectiveOp; + +// Gemm operator ${operation_name} +using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + ${tile_scheduler} +>; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + self.gemm_template_kernel_visitor = """ +using namespace cute; + +${callback_decl} + +using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ${arch}, ${opcode_class}, + cute::Shape, + cute::Shape, + cutlass::epilogue::collective::EpilogueTileAuto, + ${element_accumulator}, ${element_epilogue}, + ElementC, StrideC, ${align_c}, + ElementD, StrideD, ${align_d}, + ${epilogue_schedule}, + ${callback_name} + >::CollectiveOp; + +using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ${arch}, ${opcode_class}, + ${element_a}, ${layout_a}, ${align_a}, + ${element_b}, ${layout_b}, ${align_b}, + ${element_accumulator}, + cute::Shape, + cute::Shape, + ${stage_count_type}, + ${kernel_schedule} + >::CollectiveOp; + +// Gemm operator ${operation_name} +using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< + Shape, + 
CollectiveMainloop, + CollectiveEpilogue, + ${tile_scheduler} +>; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + + self.gemm_template_device = self.gemm_template_kernel + """ + +// Define device-level operator +using DeviceKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}${operation_suffix}>; +""" + + def emit(self, operation): + # Support built-in epilogue functors or user-defined functions + + if operation.tile_description.stages is None or operation.tile_description.stages == 0: + stage_count_type = "cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>" + else: + stage_count_type = "_" + str(operation.tile_description.stages) + + if operation.emission_type == EmissionType.Kernel: + gemm_template = self.gemm_template_kernel + else: + gemm_template = self.gemm_template_device + + kschedule = KernelScheduleType.ScheduleAuto + eschedule = EpilogueScheduleType.ScheduleAuto + tschedule = TileSchedulerType.Default + if operation.tile_description.kernel_schedule is not None: + kschedule = operation.tile_description.kernel_schedule + if operation.tile_description.epilogue_schedule is not None: + eschedule = operation.tile_description.epilogue_schedule + if operation.tile_description.tile_scheduler is not None: + tschedule = operation.tile_description.tile_scheduler + + emit_tile_m, emit_tile_n, emit_tile_k = operation.tile_description.blackwell_threadblock_shape + + values = { + "operation_name": operation.procedural_name(), + "operation_suffix": self.operation_suffix, + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[operation.A.layout], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[operation.B.layout], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[operation.C.layout], + "element_d": DataTypeTag[operation.epilogue_functor.element_output], + "layout_d": 
LayoutTag[operation.C.layout], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "element_epilogue": DataTypeTag[operation.epilogue_functor.element_epilogue], + "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(emit_tile_m), + "threadblock_shape_n": str(emit_tile_n), + "threadblock_shape_k": str(emit_tile_k), + "cluster_m": str(operation.tile_description.cluster_shape[0]), + "cluster_n": str(operation.tile_description.cluster_shape[1]), + "cluster_k": str(operation.tile_description.cluster_shape[2]), + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "align_c": str(operation.C.alignment), + "align_d": str(operation.C.alignment), + "stage_count_type": stage_count_type, + "kernel_schedule": KernelScheduleTag[kschedule], + "epilogue_schedule": EpilogueScheduleTag[eschedule], + "tile_scheduler": TileSchedulerTag[tschedule] + } + if hasattr(operation.epilogue_functor, "visitor"): + callback_name, callback_decl = operation.epilogue_functor.emit(operation) + values["callback_name"] = callback_name + values["callback_decl"] = callback_decl + return SubstituteTemplate(self.gemm_template_kernel_visitor, values) + + else: + values["epilogue_functor"] = operation.epilogue_functor.emit() + return SubstituteTemplate(gemm_template, values) + + +################################################################################################### +# Runtime module for GEMM Grouped +################################################################################################### + + +class GemmRTGrouped(GemmRTbase): + """ + GemmRTGrouped manages the CUTLASS runtime components + """ + + KernelTemplate = r""" +extern "C" +__global__ void +${operation_name}(${operation_name}${operation_suffix}::Params params) { + + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + + // Declare 
pointer to dynamic shared memory. + ${operation_name}${operation_suffix}::SharedStorage *shared_storage = + reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase); + + ${operation_name}${operation_suffix} op; + + op(params, *shared_storage); +} + """ + + HostTemplate = r""" + extern "C" { + + // precompute scheduling information + char * ${operation_name}_precompute(${operation_name}_base::Arguments const &args, int tile_count, size_t workspace_bytes) { + char* host_workspace = new char[workspace_bytes]; + ${operation_name}_base::ProblemVisitor::host_precompute( + args.host_problem_sizes, + args.problem_count, + args.threadblock_count, + (void*)host_workspace + ); + return host_workspace; + } + + // Get the size of params in bytes + int ${operation_name}_get_param_size(){ + return sizeof(${operation_name}${operation_suffix}::Params); + } + + // Get the size of dynamic shared memory in bytes + int ${operation_name}_shared_memory_size() { + return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); + } + + // Get the params as byte array + char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int tile_count, void* workspace=nullptr){ + ${operation_name}_base::Params* params; + params = new ${operation_name}_base::Params(*argument, workspace, tile_count); + + char *bytes = ((char*)(params)); + char *output = new char[sizeof(${operation_name}_base::Params)]; + for (unsigned int i = 0; i < sizeof(${operation_name}_base::Params); i ++) + output[i] = bytes[i]; + + return output; + } + + cutlass::gemm::GemmCoord ${operation_name}_get_tiled_shape( + cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord tile_size, int split_k_slices) { + return ${operation_name}_base::ThreadblockSwizzle::get_tiled_shape( + problem_size, tile_size, split_k_slices); + } + + dim3 ${operation_name}_get_grid_shape(cutlass::gemm::GemmCoord tiled_shape) { + return 
${operation_name}_base::ThreadblockSwizzle::get_grid_shape(tiled_shape); + } + } + """ + + def __init__(self, operation: "GemmOperation"): + super(GemmRTGrouped, self).__init__(operation) + self.extra_funcs = { + "precompute": None, + "get_tiled_shape": GemmCoord_, + "get_grid_shape": dim3_, + } + self.emitter = EmitGemmGroupedInstance("_type") + self.argument_type, self.epilogue_type = get_gemm_grouped_arguments(operation.epilogue_functor) + self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_int, ctypes.c_void_p] + + def host_precompute(self, arguments, workspace_bytes): + self.precompute.argtype = [ + self.argtype[0], ctypes.c_int, ctypes.c_longlong] + self.precompute.restype = ctypes.POINTER(ctypes.c_byte * workspace_bytes) + + problem_info = self.precompute( + ctypes.byref(arguments.arguments), + arguments.total_tiles, + workspace_bytes) + problem_info_array = bytearray(problem_info.contents) + + # copy to device memory + return todevice(problem_info_array).ptr + + def plan(self, arguments): + return LaunchConfiguration( + [arguments.total_tiles, 1, 1], + [self.threads, 1, 1], + self.shared_memory_capacity, + ) + + def get_workspace_size(self, arguments): + if self.operation.precompute_mode == SchedulerMode.Device: + return 0 + elif self.operation.precompute_mode == SchedulerMode.Host: + total_tiles = arguments.total_tiles + entries_per_block = 1 + return 8 * entries_per_block * total_tiles # three int32_t + + +################################################################################ +# Runtime module for GEMM and grouped GEMM +################################################################################ + + +class GemmOperationBase: + """ + CUTLASS GEMM operation + """ + + def __init__( + self, gemm_kind, arch, tile_description: TileDescription, + A: TensorDescription, B: TensorDescription, C: TensorDescription, + epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1, + api=ApiVersion.v2x, emission_type=EmissionType.Kernel, 
**kwargs): + self.operation_kind: OperationKind = OperationKind.Gemm + self.arch: int = arch + self.tile_description: TileDescription = tile_description + self.gemm_kind: GemmKind = gemm_kind + + self.api = api + self.prefix = "3x" if self.api == ApiVersion.v3x else "" + self.emission_type = emission_type + + # Optionally swap the TensorDescriptions for operands A and B and transpose their + # layouts. This is needed to mimic the transpose performed by device::GemmUniversal. + # The code below uses deep copy to avoid overwriting the original TensorDescription + self.switched = (self.api != ApiVersion.v3x and + self.emission_type == EmissionType.Kernel and + C.layout == LayoutType.ColumnMajor) + + self.A, self.B, self.C = GemmOperationBase.get_operands(A, B, C, self.switched) + + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + + if "direct_store" in kwargs: + self.direct_store = kwargs["direct_store"] + else: + self.direct_store = False + + @staticmethod + def get_operands(A: TensorDescription, B: TensorDescription, C: TensorDescription, swap: bool): + """ + Makes copies of A, B, and C, and possibly transposes their order. If ``swap`` is set, + A and B are swapped, and the layouts of A, B, and C are transposed. 
+ + :param A: description of operand A + :type A: TensorDescription + :param B: description of operand B + :type B: TensorDescription + :param C: description of operand C + :type C: TensorDescription + + :return: descriptions of operands A, B, and C + :rtype: tuple[TensorDescription] + """ + if swap: + A_out = copy.deepcopy(B) + B_out = copy.deepcopy(A) + C_out = copy.deepcopy(C) + A_out.layout = transpose_layout(A_out.layout) + B_out.layout = transpose_layout(B_out.layout) + C_out.layout = transpose_layout(C_out.layout) + else: + A_out = copy.deepcopy(A) + B_out = copy.deepcopy(B) + C_out = copy.deepcopy(C) + return A_out, B_out, C_out + + def run(self, arguments: GemmArguments) -> cuda.CUresult: + """ + Configure and launch the cuda kernel with input arguments + """ + if self.emission_type == EmissionType.Device: + raise Exception('Running a kernel via PyCUTLASS is only enabled with emission type "Kernel"') + + err = self.rt_module.run( + arguments.host_workspace, + arguments.device_workspace, + arguments.launch_config, + arguments.stream + ) + + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + + return err + + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32, + ] + return self.tile_description.math_instruction.math_operation in complex_operators + + def is_planar_complex(self): + return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray) + + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + def 
core_name(self): + """The basic operation kind is prefixed with a letter indicating the accumulation type.""" + + inst_shape = "" + inst_operation = "" + intermediate_type = "" + + math_operations_map = { + MathOperation.xor_popc: "xor", + } + + if (self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or + self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp): + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else "" + + if self.tile_description.math_instruction.instruction_shape is not None: + if self.api == ApiVersion.v3x and self.arch >= 90: + inst_shape = "%dx%dx%d" % tuple( + self.tile_description.math_instruction.instruction_shape) + else: + inst_shape = "%d%d%d" % tuple( + self.tile_description.math_instruction.instruction_shape) + else: + inst_shape = "Default" + inst_shape += math_op_string + + if (self.tile_description.math_instruction.element_a != self.A.element and + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator): + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind]) + + def extended_name(self): + """Append data types if they differ from compute type.""" + if self.is_complex(): + extended_name = "${core_name}" + else: + if (self.C.element != self.tile_description.math_instruction.element_accumulator and + self.A.element != self.tile_description.math_instruction.element_accumulator): + extended_name = "${element_c}_${core_name}_${element_a}" + elif (self.C.element == self.tile_description.math_instruction.element_accumulator and + self.A.element != self.tile_description.math_instruction.element_accumulator): + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + 
extended_name = SubstituteTemplate(extended_name, { + "element_a": DataTypeNames[self.A.element], + "element_c": DataTypeNames[self.C.element], + "core_name": self.core_name(), + }) + + return extended_name + + def extended_name_3x(self): + """Generates a string representing the MMA atom. Assumes accumulator type is C type.""" + extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( + element_a=DataTypeNames[self.A.element], + element_b=DataTypeNames[self.B.element], + element_acc=DataTypeNames[self.accumulator_type()], + element_c=DataTypeNames[self.C.element], + element_d=DataTypeNames[self.epilogue_functor.element_output], + core_name=self.core_name()) + return extended_name + + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s%s" % ( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)] + ) + return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) + + # Generates a short string representing the ABC layout tags (e.g. 
ntn or tnn) + def layout_name_3x(self): + if self.is_complex() or self.is_planar_complex(): + return "{}{}{}".format( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)], + ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)]) + else: + return "{}{}{}".format( + ShortLayoutTypeNames[self.A.layout], + ShortLayoutTypeNames[self.B.layout], + ShortLayoutTypeNames[self.C.layout]) + + # Generates a short string representing underlying kernel schedule type + def kernel_schedule_name_3x(self): + if self.tile_description.kernel_schedule is None: + return KernelScheduleSuffixes[KernelScheduleType.ScheduleAuto] + else: + return KernelScheduleSuffixes[self.tile_description.kernel_schedule] + + # Generates a short string representing underlying epilogue schedule type + def epilogue_schedule_name_3x(self): + if self.tile_description.epilogue_schedule is None: + return EpilogueScheduleSuffixes[EpilogueScheduleType.ScheduleAuto] + else: + return EpilogueScheduleSuffixes[self.tile_description.epilogue_schedule] + + def procedural_name(self): + """The full procedural name indicates architecture, extended name, tile size, and layout.""" + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + if self.api == ApiVersion.v3x and self.arch >= 90: + kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}" + return kernel_name_template.format( + p=self.prefix, + ar=self.arch, + op=opcode_class_name, + ex=self.extended_name_3x(), + tbm=self.tile_description.threadblock_shape[0], + tbn=self.tile_description.threadblock_shape[1], + tbk=self.tile_description.threadblock_shape[2], + cm=self.tile_description.cluster_shape[0], + cn=self.tile_description.cluster_shape[1], + ck=self.tile_description.cluster_shape[2], + l=self.tile_description.stages, + s=self.layout_name_3x(), + al=str(self.A.alignment), + 
k=self.kernel_schedule_name_3x(), + e=self.epilogue_schedule_name_3x() + ) + else: + threadblock = self.tile_description.procedural_name_2x() + return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format( + p=self.prefix, + op=opcode_class_name, + ex=self.extended_name(), + tb=threadblock, + l=self.layout_name(), + a=str(self.A.alignment) + ) + + def configuration_name(self): + """The full procedural name indicates architecture, extended name, tile size, and layout.""" + return self.procedural_name() + + +class GemmOperationUniversal(GemmOperationBase): + def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, + epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1, **kwargs): + api = api_version(arch, tile_description.math_instruction.opcode_class, A.element) + super(GemmOperationUniversal, self).__init__(GemmKind.Universal, arch, tile_description, + A, B, C, epilogue_functor, swizzling_functor, + api=api, **kwargs, ) + if api == ApiVersion.v3x: + if swizzling_functor == SwizzlingFunctor.StreamK: + raise Exception("Stream K swizzle functor is currently only supported for CUTLASS 2.x kernels") + self.rt_module = GemmRTUniversal3x(self) + else: + if swizzling_functor == SwizzlingFunctor.StreamK: + self.rt_module = GemmRTUniversalStreamK(self) + else: + self.rt_module = GemmRTUniversal(self) + self.argument_type = self.rt_module.argument_type + self.epilogue_type = self.rt_module.epilogue_type + + def device_op(self): + """ + Returns a new GemmOperationUniversal object that is constructed with emission type + ``EmissionType.Device``. Since the device-emitted kernel does not require swapping, + any swapping performed by the kernel-emitted operation is reversed. 
+ + :return: operation ready for device-level code emission + :rtype: GemmOperationUniversal + """ + A, B, C = GemmOperationBase.get_operands(self.A, self.B, self.C, self.switched) + return GemmOperationUniversal(self.arch, self.tile_description, A, B, C, + self.epilogue_functor, self.swizzling_functor, + emission_type=EmissionType.Device, direct_store=self.direct_store) + + +class GemmOperationGrouped(GemmOperationBase): + def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C, + epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1, **kwargs): + super(GemmOperationGrouped, self).__init__(GemmKind.Grouped, arch, tile_description, + A, B, C, epilogue_functor, swizzling_functor, **kwargs) + assert "precompute_mode" in kwargs.keys(), "missing keyword arguement 'precompute_mode'." + self.precompute_mode = kwargs["precompute_mode"] + self.rt_module = GemmRTGrouped(self) + self.argument_type = self.rt_module.argument_type + self.epilogue_type = self.rt_module.epilogue_type + + def device_op(self): + """ + Returns a new GemmOperationGrouped object that is constructed with emission type + ``EmissionType.Device``. Since the device-emitted kernel does not require swapping, + any swapping performed by the kernel-emitted operation is reversed. 
+ + :return: operation ready for device-level code emission + :rtype: GemmOperationGrouped + """ + A, B, C = GemmOperationBase.get_operands(self.A, self.B, self.C, self.switched) + return GemmOperationGrouped( + self.arch, self.tile_description, A, B, C, self.epilogue_functor, + self.swizzling_functor, emission_type=EmissionType.Device, + direct_store=self.direct_store, precompute_mode=self.precompute_mode, ) + + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + + +class EmitGemmUniversalInstance: + """Responsible for emitting a CUTLASS template definition""" + + def __init__( + self, + operation_suffix="", + direct_store=False + ): + self.operation_suffix = operation_suffix + self.direct_store = direct_store + self.includes = [ + "cutlass/cutlass.h", + "cutlass/gemm_coord.h", + "cutlass/numeric_types.h", + "cutlass/arch/arch.h", + "cutlass/arch/mma.h", + "cutlass/layout/matrix.h", + "cutlass/gemm/device/gemm.h", + "cutlass/gemm/device/gemm_universal_adapter.h", + "cutlass/gemm/kernel/default_gemm_universal.h", + ] + if self.direct_store: + self.includes.append( + "cutlass/epilogue/threadblock/default_epilogue_direct_store.h" + ) + self.gemm_template_kernel = """ +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, 
${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operation} +>::GemmKernel; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + + self.gemm_template_device = """ +// Gemm operator ${operation_name} +using DeviceKernel = + typename cutlass::gemm::device::GemmUniversal< + // Data type and layout of operand A + ${element_a}, ${layout_a}, + // Data type and layout of operand B + ${element_b}, ${layout_b}, + // Data type and layout of operand C + ${element_c}, ${layout_c}, + // Data type of accumulator + ${element_accumulator}, + // Class of operation + ${opcode_class}, + // Compute capability of the target kernel + ${arch}, + // Threadblock tile shape + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + // Warp tile shape + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + // Instruction shape + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + // Epilogue functor + ${epilogue_functor}, + // Swizzling function + ${swizzling_functor}, + // Number of pipeline stages + ${stages}, + // Alignment of operands A and B + ${align_a}, ${align_b}, + // Type of math operation + ${math_operation}, + // Complex transform types of operands A and B + ${transform_a}, ${transform_b} + >; +""" + self.gemm_template_direct_store = """ +// Gemm operator ${operation_name} +using ${operation_name}_default = + typename cutlass::gemm::kernel::DefaultGemmUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + 
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operation} +>::GemmKernel; + +using ${operation_name}_base = + cutlass::gemm::kernel::GemmUniversal< + ${operation_name}_default::Mma, + cutlass::epilogue::threadblock::DefaultEpilogueDirectStore< + ${operation_name}_default::Epilogue + >::Epilogue, + ${operation_name}_default::ThreadblockSwizzle + >; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + self.gemm_template_kernel_visitor = """ + +using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + ${element_c}, + ${align_c}, + ${epilogue_stages} /* epilogue stages */ +>; + +${callback_decl} + +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, ${align_c}, + ${element_accumulator}, + ${element_epilogue}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${callback_name}, + ${swizzling_functor}, + ${stages}, + ${math_operation}, + ${epilogue_stages} /* epilogue stages */ +>::GemmKernel; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + 
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + def emit(self, operation): + threadblock_shape = operation.tile_description.threadblock_shape + warp_count = operation.tile_description.warp_count + + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + instance_layout_A, instance_layout_B, instance_layout_C = \ + (operation.A.layout, operation.B.layout, operation.C.layout) + + if operation.emission_type == EmissionType.Kernel: + if self.direct_store: + gemm_template = self.gemm_template_direct_store + else: + gemm_template = self.gemm_template_kernel + else: + gemm_template = self.gemm_template_device + + values = { + "operation_name": operation.procedural_name(), + "operation_suffix": self.operation_suffix, + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[instance_layout_A], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[instance_layout_B], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[instance_layout_C], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]), + "instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]), + "instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]), + "swizzling_functor": 
SwizzlingFunctorTag[operation.swizzling_functor], + "stages": str(operation.tile_description.stages), + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "transform_a": ComplexTransformTag[operation.A.complex_transform], + "transform_b": ComplexTransformTag[operation.B.complex_transform], + "math_operation": MathOperationTag[operation.tile_description.math_instruction.math_operation], + } + + if hasattr(operation.epilogue_functor, "visitor"): + self.includes += [ + "cutlass/epilogue/threadblock/fusion/visitors.hpp", + "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" + ] + callback_name, callback_decl = operation.epilogue_functor.emit(operation) + values["callback_name"] = callback_name + values["callback_decl"] = callback_decl + values["align_c"] = str(operation.C.alignment) + values["element_epilogue"] = DataTypeTag[operation.epilogue_functor.element_epilogue] + if hasattr(operation.epilogue_functor, "epilogue_stages"): + epilogue_stages = operation.epilogue_functor.epilogue_stages + else: + epilogue_stages = 1 + values["epilogue_stages"] = str(epilogue_stages) + return SubstituteTemplate(self.gemm_template_kernel_visitor, values) + else: + values["epilogue_functor"] = operation.epilogue_functor.emit() + return SubstituteTemplate(gemm_template, values) + + +class EmitGemmGroupedInstance: + """Responsible for emitting a CUTLASS template definition""" + + def __init__(self, operation_suffix=""): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/numeric_types.h", + "cutlass/arch/arch.h", + "cutlass/arch/mma.h", + "cutlass/layout/matrix.h", + "cutlass/gemm/kernel/gemm_grouped.h", + "cutlass/gemm/kernel/default_gemm_grouped.h", + ] + self.gemm_template_kernel = """ +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmGrouped< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, 
${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${precompute_mode}, + ${math_operation} +>::GemmKernel; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + self.gemm_template_device = ( + self.gemm_template_kernel + + """ +using DeviceKernel = cutlass::gemm::device::GemmGrouped<${operation_name}_base>; +""" + ) + + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmGrouped<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + def emit(self, operation): + threadblock_shape = operation.tile_description.threadblock_shape + warp_count = operation.tile_description.warp_count + + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + instance_layout_A, instance_layout_B, instance_layout_C = \ + (operation.A.layout, operation.B.layout, operation.C.layout) + + # Support built-in epilogue functors or user-defined functions + epilogue_functor = operation.epilogue_functor.emit() + + values = { + "operation_name": operation.procedural_name(), + "operation_suffix": self.operation_suffix, + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[instance_layout_A], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[instance_layout_B], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[instance_layout_C], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": 
OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]), + "instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]), + "instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]), + "epilogue_functor": epilogue_functor, + "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor], + "stages": str(operation.tile_description.stages), + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "transform_a": ComplexTransformTag[operation.A.complex_transform], + "transform_b": ComplexTransformTag[operation.B.complex_transform], + "precompute_mode": SchedulerModeTag[operation.precompute_mode], + "math_operation": MathOperationTag[operation.tile_description.math_instruction.math_operation], + } + + if operation.emission_type == EmissionType.Kernel: + gemm_template = self.gemm_template_kernel + else: + gemm_template = self.gemm_template_device + + return SubstituteTemplate(gemm_template, values) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/library.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/library.py new file mode 100644 index 0000000000000000000000000000000000000000..a77b302dcccf330cc0e0f9b3f1290ab7030c5932 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/library.py @@ -0,0 
+1,509 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
"""
Common data types and string names/tags for them
"""

import enum

from cutlass_library import (
    ComplexTransform,
    DataType,
    DataTypeSize,
    EpilogueScheduleType,
    KernelScheduleSuffixes,
    KernelScheduleType,
    MathOperation,
    OpcodeClass,
    OperationKind,
    TileSchedulerType
)


# The following block implements enum.auto() for Python 3.5 variants that don't include it such
# as the default 3.5.2 on Ubuntu 16.04.
#
# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility

try:
    from enum import auto as enum_auto
except ImportError:
    __cutlass_library_auto_enum = 0

    def enum_auto() -> int:
        """
        Fallback for `enum.auto`: returns consecutive integers starting at 0, one per call.
        """
        global __cutlass_library_auto_enum
        i = __cutlass_library_auto_enum
        __cutlass_library_auto_enum += 1
        return i


class DataTypeSizeBytes:
    """
    Static class to mimic the `DataTypeSize` dictionary, but with checks for whether the
    data type key is less than a full byte or a non-integer number of bytes.
    """

    @staticmethod
    def __class_getitem__(datatype):
        """
        Returns the number of bytes in size the data type is. Raises an exception if the data type
        is either less than a full byte or a non-integer number of bytes in size.

        :param datatype: data type to query

        :return: number of bytes the data type occupies
        :rtype: int
        """
        bits = DataTypeSize[datatype]
        if bits < 8:
            raise Exception(
                f"Data type {datatype} is less than one byte in size."
            )
        elif bits % 8 != 0:
            # The offending data type is interpolated into the message (it used to be
            # emitted as the literal word "datatype").
            raise Exception(
                f"Data type {datatype} is not an integer number of bytes."
            )
        return bits // 8


class SchedulerMode(enum.Enum):
    Device = enum_auto()
    Host = enum_auto()


SchedulerModeTag = {
    SchedulerMode.Device: "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly",
    SchedulerMode.Host: "cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute",
}


ShortSchedulerModeNames = {SchedulerMode.Device: "Device", SchedulerMode.Host: "Host"}


class FunctionalOp(enum.Enum):
    AtomicAdd = enum_auto()
    AtomicMaximum = enum_auto()
    Divides = enum_auto()
    Maximum = enum_auto()
    Minimum = enum_auto()
    Minus = enum_auto()
    Multiplies = enum_auto()
    MultiplyAdd = enum_auto()
    Plus = enum_auto()
    Exp = enum_auto()


FunctionalOpTag = {
    FunctionalOp.AtomicAdd: "cutlass::atomic_add",
    FunctionalOp.AtomicMaximum: "cutlass::atomic_maximum",
    FunctionalOp.Divides: "cutlass::divides",
    FunctionalOp.Maximum: "cutlass::maximum",
    FunctionalOp.Minimum: "cutlass::minimum",
    FunctionalOp.Minus: "cutlass::minus",
    FunctionalOp.Multiplies: "cutlass::multiplies",
    FunctionalOp.MultiplyAdd: "cutlass::multiply_add",
    FunctionalOp.Plus: "cutlass::plus",
    FunctionalOp.Exp: "cutlass::fast_exp_op",
}


class ActivationOp(enum.Enum):
    DGelu = enum_auto()
    Gelu = enum_auto()
    GeluTaylor = enum_auto()
    HardSwish = enum_auto()
    Identity = enum_auto()
    LeakyReLU = enum_auto()
    ReLU = enum_auto()
    Sigmoid = enum_auto()
    SiLU = enum_auto()
    Tanh = enum_auto()


ActivationOpTag = {
    ActivationOp.DGelu: "cutlass::epilogue::thread::dGELU",
    ActivationOp.Gelu: "cutlass::epilogue::thread::GELU",
    ActivationOp.GeluTaylor: "cutlass::epilogue::thread::GELU_taylor",
    ActivationOp.HardSwish: "cutlass::epilogue::thread::HardSwish",
    ActivationOp.Identity: "cutlass::epilogue::thread::Identity",
    ActivationOp.LeakyReLU: "cutlass::epilogue::thread::LeakyReLU",
    ActivationOp.ReLU: "cutlass::epilogue::thread::ReLu",
    ActivationOp.Sigmoid: "cutlass::epilogue::thread::Sigmoid",
    ActivationOp.SiLU: "cutlass::epilogue::thread::SiLu",
    ActivationOp.Tanh: "cutlass::epilogue::thread::Tanh",
}


def op_tag(op) -> str:
    """
    Dispatches `op` to the appropriate *Tag dictionary depending on whether
    `op` is an ActivationOp or FunctionalOp. This is useful for cases in which
    either type can be used.

    :param op: operation to emit a tag for
    :type op: ActivationOp | FunctionalOp

    :return: tag corresponding to op
    :rtype: str
    """
    if isinstance(op, ActivationOp):
        return ActivationOpTag[op]
    elif isinstance(op, FunctionalOp):
        return FunctionalOpTag[op]
    else:
        raise Exception(f"Unexpected op type {op}. Must be one of ActivationOp or FunctionalOp.")


class FloatRoundStyle(enum.Enum):
    ToNearest = enum_auto()
    ToNearestSatfinite = enum_auto()
    Indeterminate = enum_auto()
    TowardZero = enum_auto()
    TowardInfinity = enum_auto()
    TowardNegInfinity = enum_auto()
    HalfUlpTruncDntz = enum_auto()
    HalfUlpTruncate = enum_auto()


FloatRoundStyleTag = {
    FloatRoundStyle.ToNearest: "cutlass::FloatRoundStyle::round_to_nearest",
    FloatRoundStyle.ToNearestSatfinite: "cutlass::FloatRoundStyle::round_to_nearest_satfinite",
    FloatRoundStyle.Indeterminate: "cutlass::FloatRoundStyle::round_indeterminate",
    FloatRoundStyle.TowardZero: "cutlass::FloatRoundStyle::round_toward_zero",
    FloatRoundStyle.TowardInfinity: "cutlass::FloatRoundStyle::round_toward_infinity",
    FloatRoundStyle.TowardNegInfinity: "cutlass::FloatRoundStyle::round_toward_neg_infinity",
    FloatRoundStyle.HalfUlpTruncDntz: "cutlass::FloatRoundStyle::round_half_ulp_trunc_dntz",
    FloatRoundStyle.HalfUlpTruncate: "cutlass::FloatRoundStyle::round_half_ulp_truncate",
}


class MathInstruction:
    """
    Description of the lowest-level matrix-multiply-accumulate operation to be used in a kernel
    """

    def __init__(
        self,
        instruction_shape,
        element_a,
        element_b,
        element_accumulator,
        opcode_class=OpcodeClass.Simt,
        math_operation=MathOperation.multiply_add,
    ):
        """
        :param instruction_shape: size of the [M, N, K] dimensions of the instruction
        :type instruction_shape: list or tuple
        :param element_a: data type of operand A
        :param element_b: data type of operand B
        :param element_accumulator: data type used in accumulation
        :param opcode_class: higher-level class of the instruction (e.g., SIMT or Tensor Core)
        :type opcode_class: cutlass_library.library.OpcodeClass
        :param math_operation: the type of low-level operation to be performed (e.g., multiply accumulate)
        :type math_operation: MathOperation
        """
        self.instruction_shape = instruction_shape
        self.element_a = element_a
        self.element_b = element_b
        self.element_accumulator = element_accumulator
        self.opcode_class = opcode_class
        self.math_operation = math_operation


def to_blackwell_threadblock_shape(tile_description, cluster_shape, kernel_schedule):
    """
    Derives the shape stored as `TileDescription.blackwell_threadblock_shape`.

    When `cluster_shape[0] > 0`, the threadblock shape is divided element-wise by the cluster
    shape, and its M extent is doubled if the kernel schedule's suffix contains "2sm".
    Otherwise the math instruction's shape is returned (the same object, not a copy).

    :param tile_description: object exposing `threadblock_shape` and `math_instruction`
    :param cluster_shape: [X, Y, Z] cluster shape
    :param kernel_schedule: kernel schedule, or None
    :type kernel_schedule: cutlass_library.KernelScheduleType

    :return: the derived shape and whether the schedule is a "2sm" schedule
    :rtype: tuple
    """
    blackwell_threadblock_shape = tile_description.threadblock_shape
    is_2sm = False if kernel_schedule is None else ("2sm" in KernelScheduleSuffixes[kernel_schedule])
    if cluster_shape[0] > 0:
        blackwell_threadblock_shape = [
            tile_description.threadblock_shape[0] // cluster_shape[0],
            tile_description.threadblock_shape[1] // cluster_shape[1],
            tile_description.threadblock_shape[2] // cluster_shape[2]
        ]
        if is_2sm:
            blackwell_threadblock_shape[0] *= 2
    else:
        blackwell_threadblock_shape = tile_description.math_instruction.instruction_shape
    return blackwell_threadblock_shape, is_2sm


class TileDescription:
    """
    Description of a tile of computation to be performed in the kernel, encompassing threadblock, cluster, and warp shapes,
    stage count, and math instruction specification
    """

    def __init__(
        self,
        threadblock_shape,
        stages,
        warp_count,
        math_instruction,
        cluster_shape=None,
        kernel_schedule: KernelScheduleType = None,
        epilogue_schedule: EpilogueScheduleType = None,
        tile_scheduler: TileSchedulerType = None
    ):
        """
        :param threadblock_shape: shape of a threadblock tile
        :type threadblock_shape: list or tuple
        :param stages: number of pipeline stages in the operation. For SM90 kernels, this can be set to `None` and the maximum
                       number of stages that can be supported for an operation on a given architecture will be computed at a later time
        :type stages: int or None
        :param warp_count: number of warps in each [M, N, K] dimension of a threadblock tile
        :type warp_count: list, tuple, or None
        :param math_instruction: specification of the instruction type and shape to be performed and the types of its operands
        :type math_instruction: MathInstruction
        :param cluster_shape: number of threadblocks in the [X, Y, Z] dimensions of a threadblock cluster.
                              Defaults to `[1, 1, 1]` when omitted or `None`
        :param kernel_schedule: type of kernel schedule to use (only available for SM90+)
        :type kernel_schedule: cutlass_library.KernelScheduleType
        :param epilogue_schedule: type of epilogue schedule to use (only available for SM90+)
        :type epilogue_schedule: cutlass_library.EpilogueScheduleType
        :param tile_scheduler: type of tile scheduler to use (only available for SM90+)
        :type tile_scheduler: cutlass_library.TileSchedulerType
        """
        if ((kernel_schedule is None and epilogue_schedule is not None) or
                (kernel_schedule is not None and epilogue_schedule is None)):
            raise Exception("Kernel and epilogue schedule must either both be Auto or neither be Auto.")

        # A fresh list per instance: a list literal used as the default argument would be
        # shared (and mutable) across every TileDescription built without a cluster shape.
        if cluster_shape is None:
            cluster_shape = [1, 1, 1]

        self.threadblock_shape = threadblock_shape
        self.cluster_shape = cluster_shape
        self.kernel_schedule = kernel_schedule
        self.epilogue_schedule = epilogue_schedule
        self.tile_scheduler = tile_scheduler
        self.stages = stages

        self.math_instruction = math_instruction
        self.instruction_shape = math_instruction.instruction_shape

        # Number of warps along x, y, z directions
        self.warp_count = warp_count

        # Must come last: the helper reads threadblock_shape and math_instruction from self
        self.blackwell_threadblock_shape, self.is_2sm = to_blackwell_threadblock_shape(self, self.cluster_shape, self.kernel_schedule)

    def clone_and_update(self, td: dict):
        """
        Returns a new TileDescription in which every supported key present in `td` overrides
        the value held by this instance; all other values are carried over unchanged.

        :param td: overrides keyed by attribute name (e.g., "stages", "cluster_shape")
        :type td: dict

        :return: updated copy of this tile description
        :rtype: TileDescription
        """
        keys = (
            "cluster_shape",
            "threadblock_shape",
            "warp_count",
            "stages",
            "instruction_shape",
            "kernel_schedule",
            "epilogue_schedule",
            "tile_scheduler",
        )
        attrs = {key: td[key] if key in td else getattr(self, key) for key in keys}

        # The instruction shape is not a constructor argument; it travels inside the
        # math instruction, which otherwise keeps this instance's operand types.
        instruction_shape = attrs.pop("instruction_shape")
        attrs["math_instruction"] = MathInstruction(
            instruction_shape,
            self.math_instruction.element_a,
            self.math_instruction.element_b,
            self.math_instruction.element_accumulator,
            self.math_instruction.opcode_class,
            self.math_instruction.math_operation
        )

        return TileDescription(**attrs)

    @property
    def num_threads(self):
        """
        Returns the number of threads in the threadblock

        :return: number of threads in the threadblock
        :rtype: int or None (if warp count is None)
        """
        if self.warp_count is not None:
            threads = 32
            for cnt in self.warp_count:
                threads *= cnt
            return threads
        return None

    def procedural_name(self):
        """
        Returns a name identifying the tile description

        :return: name identifying the tile description
        :rtype: str
        """
        emit_stages = 0 if self.stages is None else self.stages
        name = "%dx%dx%d_%dx%d_%dx%d" % (
            self.cluster_shape[0],
            self.cluster_shape[1],
            self.cluster_shape[2],
            self.threadblock_shape[0],
            self.threadblock_shape[1],
            self.threadblock_shape[2],
            emit_stages
        )

        return name

    def procedural_name_2x(self):
        """
        Returns a name identifying the tile description without the cluster shape

        :return: name identifying the tile description
        :rtype: str
        """
        # Same convention as procedural_name(): an unset stage count is emitted as 0
        # rather than failing in the "%d" conversion.
        emit_stages = 0 if self.stages is None else self.stages
        return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], emit_stages)

    def __str__(self):
        """
        Returns a string containing each of the tile description's values

        :return: contents of tile description
        :rtype: str
        """
        if self.kernel_schedule is not None:
            kschedule = self.kernel_schedule
        else:
            kschedule = KernelScheduleType.ScheduleAuto

        if self.epilogue_schedule is not None:
            eschedule = self.epilogue_schedule
        else:
            eschedule = EpilogueScheduleType.ScheduleAuto

        if self.tile_scheduler is not None:
            tschedule = self.tile_scheduler.name
        else:
            tschedule = "None"
        return f"""
{{
  ClusterShape: {self.cluster_shape}
  ThreadblockShape: {self.threadblock_shape}
  WarpCount: {self.warp_count}
  Stages: {self.stages if self.stages is not None else 'Auto'}
  InstructionShape: {self.math_instruction.instruction_shape}
  Kernel schedule: {kschedule.name}
  Epilogue schedule: {eschedule.name}
  TileScheduler: {tschedule}
}}"""


class TensorDescription:
    """
    Description of a tensor operand: element type, layout, alignment, and complex transform
    """

    def __init__(self, element, layout, alignment=1, complex_transform=ComplexTransform.none):
        """
        :param element: data type of the tensor's elements
        :param layout: layout of the tensor
        :param alignment: requested alignment in elements. For non-void element types it is
                          capped at `128 // DataTypeSize[element]`
        :type alignment: int
        :param complex_transform: complex transform applied to the operand
        :type complex_transform: cutlass_library.ComplexTransform
        """
        self.element = element
        self.layout = layout
        if element != DataType.void:
            self.alignment = min(128 // DataTypeSize[self.element], alignment)
        else:
            self.alignment = alignment
        self.complex_transform = complex_transform


def CalculateSmemUsagePerStage(operation):
    """
    Returns the amount of shared memory in bytes consumed in a single stage of a kernel.

    :param operation: operation for which the per-stage usage should be computed. The stage count
                      set via `operation.tile_description.stages` is not used in this calculation
    :type operation: cutlass_cppgen.backend.Operation

    :return: number of bytes of shared memory consumed by a single stage
    :rtype: int
    """
    m, n, k = operation.tile_description.threadblock_shape

    if operation.operation_kind == OperationKind.Gemm:
        stage_barrier_bytes = 32
        return (
            (DataTypeSize[operation.A.element] * m * k // 8)
            + (DataTypeSize[operation.B.element] * k * n // 8)
            + stage_barrier_bytes
        )
    else:
        raise Exception("Unsupported operation kind {}.".format(operation.operation_kind))


def CalculateSmemUsage(operation):
    """
    Returns the amount of shared memory in bytes consumed by a kernel: the per-stage usage
    multiplied by `operation.tile_description.stages`.

    :param operation: operation for which the shared memory usage should be computed
    :type operation: cutlass_cppgen.backend.Operation

    :return: number of bytes of shared memory consumed across all stages
    :rtype: int
    """
    return operation.tile_description.stages * CalculateSmemUsagePerStage(operation)


class ApiVersion(enum.Enum):
    """
    Differentiate between CUTLASS 2.x and 3.x API versions
    """

    v2x = enum_auto()
    v3x = enum_auto()


def api_version(arch, opclass, dtype):
    """
    Returns whether the architecture, opcode class, and datatype in question require using CUTLASS 2.x
    or 3.x for code emission.

    :param arch: compute capability of device on which to run
    :type arch: int
    :param opclass: class of the operation being performed
    :type opclass: cutlass_library.OpcodeClass
    :param dtype: data type to be used in operation (assumes that ElementA and ElementB are the same)
    :type dtype: cutlass_library.DataType

    :return: API version to be used in code emission
    :rtype: ApiVersion
    """
    if (arch in [90, 100, 101, 103] and
            opclass == OpcodeClass.TensorOp and
            (dtype != DataType.f64)):
        return ApiVersion.v3x
    else:
        return ApiVersion.v2x


class EmissionType(enum.Enum):
    """
    Tags for whether to emit a kernel- or device-level operation
    """

    Kernel = enum_auto()
    Device = enum_auto()
+# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import numpy as np

import cutlass_cppgen
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
from cutlass_cppgen.utils.lazy_import import lazy_import

if cutlass_cppgen.use_rmm:
    import rmm
else:
    cudart = lazy_import("cuda.cudart")


class PoolMemoryManager:
    """
    Installs an RMM pool memory resource, wrapped in a tracking adaptor, as the current
    device resource. Only usable when `cutlass_cppgen.use_rmm` is set, since `rmm` is
    imported only in that case.
    """

    def __init__(self, init_pool_size: int, max_pool_size: int) -> None:
        """
        :param init_pool_size: initial size of the pool, forwarded to RMM
        :type init_pool_size: int
        :param max_pool_size: maximum size of the pool, forwarded to RMM
        :type max_pool_size: int
        """
        self.pool = rmm.mr.PoolMemoryResource(
            rmm.mr.CudaMemoryResource(),
            initial_pool_size=init_pool_size,
            maximum_pool_size=max_pool_size
        )
        self.mr = rmm.mr.TrackingResourceAdaptor(self.pool)
        rmm.mr.set_current_device_resource(self.mr)

    def pool_size(self):
        """
        Returns the size reported by the underlying pool resource
        """
        return self.pool.pool_size()


class DevicePtrWrapper:
    """
    Wrapper around a pointer to device memory to provide a uniform interface with the RMM DeviceBuffer
    (at least in terms of the interface used by the CUTLASS Python interface)
    """
    def __init__(self, dev_ptr):
        self.dev_ptr = dev_ptr

    @property
    def ptr(self):
        return self.dev_ptr


def _todevice(host_data):
    """
    Helper for transferring host data to device memory

    :param host_data: array to copy to the device
    :type host_data: numpy.ndarray

    :return: an RMM DeviceBuffer when RMM is in use, otherwise a DevicePtrWrapper
    """
    if cutlass_cppgen.use_rmm:
        return rmm.DeviceBuffer.to_device(host_data.tobytes())

    # cudaMemcpy reads `nbytes` consecutive bytes starting at the array's data pointer.
    # That is only the array's logical contents when it is C-contiguous, so strided or
    # transposed views are compacted first. This matches what `tobytes()` produces on
    # the RMM path; already-contiguous input is not copied.
    host_data = np.ascontiguousarray(host_data)
    nbytes = host_data.nbytes
    dev_ptr_wrapper = device_mem_alloc(nbytes)
    err, = cudart.cudaMemcpy(
        dev_ptr_wrapper.ptr,
        host_data.__array_interface__['data'][0],
        nbytes,
        cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
    )
    if err != cudart.cudaError_t.cudaSuccess:
        # Best-effort release so a failed copy does not leak the allocation; the copy
        # error is the one reported.
        cudart.cudaFree(dev_ptr_wrapper.ptr)
        raise Exception(f"cudaMemcpy failed with error {err}")
    return dev_ptr_wrapper


def todevice(host_data, dtype=np.float32):
    """
    Pass the host_data to device memory

    :param host_data: data to copy to the device
    :type host_data: list or numpy.ndarray
    :param dtype: element type used when `host_data` is a list; ignored for arrays

    :return: handle to the device copy (see `_todevice`)

    :raises TypeError: if `host_data` is neither a list nor a numpy array
    """
    if isinstance(host_data, list):
        return _todevice(np.array(host_data, dtype=dtype))
    elif is_numpy_tensor(host_data):
        return _todevice(host_data)
    # Fail here rather than hand the caller an implicit None in place of a device buffer
    raise TypeError(
        f"Unsupported host data type {type(host_data)}. Must be a list or a numpy array."
    )


def device_mem_alloc(size):
    """
    Allocates `size` bytes of device memory

    :param size: number of bytes to allocate
    :type size: int

    :return: an RMM DeviceBuffer when RMM is in use, otherwise a DevicePtrWrapper
    """
    if cutlass_cppgen.use_rmm:
        return rmm.DeviceBuffer(size=size)
    else:
        err, ptr = cudart.cudaMalloc(size)
        if err != cudart.cudaError_t.cudaSuccess:
            raise Exception(f"cudaMalloc failed with error {err}")
        return DevicePtrWrapper(ptr)


def align_size(size, alignment=256):
    """
    Rounds `size` up to the nearest multiple of `alignment`

    :param size: size to round up
    :type size: int
    :param alignment: granularity to round to
    :type alignment: int

    :return: smallest multiple of `alignment` that is not less than `size`
    :rtype: int
    """
    return ((size + alignment - 1) // alignment) * alignment


def create_memory_pool(init_pool_size=0, max_pool_size=2 ** 34):
    """
    Creates the RMM-backed memory pool when RMM is in use

    :param init_pool_size: initial size of the pool
    :type init_pool_size: int
    :param max_pool_size: maximum size of the pool
    :type max_pool_size: int

    :return: the pool manager, or None when RMM is not in use
    :rtype: PoolMemoryManager or None
    """
    if cutlass_cppgen.use_rmm:
        memory_pool = PoolMemoryManager(init_pool_size=init_pool_size, max_pool_size=max_pool_size)
        return memory_pool
    else:
        return None
import ctypes
from cutlass_cppgen.utils.lazy_import import lazy_import
cuda = lazy_import("cuda.cuda")

from cutlass_cppgen.backend.utils.device import device_cc

# Cached result of supports_cluster_launch(); None until the first query
_supports_cluster_launch = None


def supports_cluster_launch():
    """
    Returns whether kernels should be launched through the cluster-aware launch path.

    This is true when `device_cc()` is one of 90, 100, 101, or 103 and the installed `cuda`
    Python package is at version 11.8 or newer. The answer is computed on the first call
    and cached for the rest of the process.

    :rtype: bool
    """
    global _supports_cluster_launch
    if _supports_cluster_launch is None:
        # Only the first call pays for importing the package and parsing its version.
        # Any "rc"/".post" suffix is dropped before splitting into numeric components.
        from cuda import __version__
        _version_splits = [int(x) for x in __version__.split("rc")[0].split(".post")[0].split(".")]
        major, minor = _version_splits[0], _version_splits[1]
        _supports_cluster_launch = device_cc() in [90, 100, 101, 103] and (major > 11 or (major == 11 and minor >= 8))
    return _supports_cluster_launch


class LaunchConfiguration:
    """
    Grid dimensions, block dimensions, and dynamic shared memory size for one kernel launch
    """

    def __init__(self, grid=None, block=None, smem=0):
        """
        :param grid: grid dimensions [x, y, z]; defaults to [1, 1, 1]
        :type grid: list
        :param block: block dimensions [x, y, z]; defaults to [1, 1, 1]
        :type block: list
        :param smem: dynamic shared memory size passed to the launch
        :type smem: int
        """
        # Fresh lists per instance: list literals used as default arguments would be
        # shared across every configuration constructed without explicit dimensions.
        self.grid = [1, 1, 1] if grid is None else grid
        self.block = [1, 1, 1] if block is None else block
        self.shared_memory_capacity = smem


class ExecutableOperation:
    """
    Base class for operations that can be launched through the CUDA driver API.

    `module` and `kernel` start out as None and are expected to be assigned before `run`
    is called; the planning and workspace hooks below must be overridden by subclasses.
    """

    def __init__(self, operation):
        self.operation = operation
        self.module = None
        self.kernel = None

    def name(self):
        """
        Returns the procedural name of the wrapped operation
        """
        return self.operation.procedural_name()

    def emit(self):
        """
        Returns the source emitted for this operation (empty in the base class)
        """
        return ""

    def can_implement(self, configuration, arguments):
        raise NotImplementedError()

    def get_host_workspace_size(self, arguments):
        raise NotImplementedError()

    def get_device_workspace_size(self, arguments):
        raise NotImplementedError()

    def plan(self, arguments):
        raise NotImplementedError()

    def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream=None):
        raise NotImplementedError()

    def run_with_clusters(self, launch_config, kernel_params, stream=None):
        """
        Launches the kernel via `cuLaunchKernelEx`, attaching a cluster-dimension launch
        attribute when the operation's tile description defines `cluster_shape`.

        :param launch_config: grid, block, and shared memory settings
        :type launch_config: LaunchConfiguration
        :param kernel_params: packed kernel parameters
        :param stream: stream to launch on; stream 0 is used when falsy

        :return: the driver API result code, which is not raised as an exception
        """
        if not stream:
            stream = cuda.CUstream(0)
        if hasattr(self.operation, "tile_description") and hasattr(self.operation.tile_description, "cluster_shape"):
            attr = cuda.CUlaunchAttribute()
            attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = self.operation.tile_description.cluster_shape
            attr.id = cuda.CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
            attrs = [attr]

            # Allow for non-portable cluster sizes
            err, = cuda.cuFuncSetAttribute(
                self.kernel, cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)
            if err != cuda.CUresult.CUDA_SUCCESS:
                return err
        else:
            attrs = []

        config = cuda.CUlaunchConfig()
        config.gridDimX, config.gridDimY, config.gridDimZ = launch_config.grid
        config.blockDimX, config.blockDimY, config.blockDimZ = launch_config.block
        config.sharedMemBytes = launch_config.shared_memory_capacity
        config.hStream = stream
        config.attrs = attrs
        config.numAttrs = len(attrs)

        err, = cuda.cuLaunchKernelEx(
            config, f=self.kernel, kernelParams=kernel_params, extra=0)
        return err

    def run_without_clusters(self, launch_config, kernel_params, stream=None):
        """
        Launches the kernel via `cuLaunchKernel`.

        :param launch_config: grid, block, and shared memory settings
        :type launch_config: LaunchConfiguration
        :param kernel_params: packed kernel parameters
        :param stream: stream to launch on; stream 0 is used when falsy

        :return: the driver API result code, which is not raised as an exception
        """
        if not stream:
            stream = cuda.CUstream(0)
        err, = cuda.cuLaunchKernel(
            self.kernel,
            launch_config.grid[0], launch_config.grid[1], launch_config.grid[2],
            launch_config.block[0], launch_config.block[1], launch_config.block[2],
            launch_config.shared_memory_capacity,
            stream,
            kernel_params,
            0)

        return err

    def run(self, host_workspace, device_workspace, launch_config, stream=None):
        """
        Launches the kernel with the contents of `host_workspace` as its single parameter.

        :param host_workspace: writable buffer (e.g., a bytearray) holding the kernel's parameters
        :param device_workspace: not used by this method
        :param launch_config: grid, block, and shared memory settings
        :type launch_config: LaunchConfiguration
        :param stream: stream to launch on; stream 0 is used when falsy

        :return: the driver API result code from the launch
        """
        if not stream:
            stream = cuda.CUstream(0)
        # View the workspace in place as a C char array, then pass a one-element array of
        # pointers whose only entry is the address of that array.
        cArg = (ctypes.c_char * len(host_workspace)).from_buffer(host_workspace)
        packed = (ctypes.c_void_p * 1)()
        packed[0] = ctypes.addressof(cArg)

        if supports_cluster_launch():
            return self.run_with_clusters(launch_config, packed, stream)
        else:
            return self.run_without_clusters(launch_config, packed, stream)
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ +from __future__ import annotations + +import ctypes +from typing import Union + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +cudart = lazy_import("cuda.cudart") +import numpy as np + +from cutlass_library import ( + DataTypeNames, + DataTypeSize, + DataTypeTag, + LayoutType, + SubstituteTemplate +) + +import cutlass_cppgen +from cutlass_cppgen.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params +from cutlass_cppgen.backend.frontend import NumpyFrontend, TorchFrontend +from cutlass_cppgen.backend.library import TensorDescription +from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper +from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration +from cutlass_cppgen.shape import MatrixCoord +from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_tensor + + +class ReductionOperation: + pass + + +class ReductionArguments: + """ + Arguments of reduction + """ + + def __init__( + self, + operation: ReductionOperation, + problem_size: "list[int]", + partitions: int, + workspace: cuda.CUdeviceptr, + 
destination: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]", + source: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]", + **kwargs, + ) -> None: + # tensor_C can be interpreted as the bias with bias=True in keyword args + if "bias" in kwargs.keys(): + self.bias = kwargs["bias"] + else: + # by default, tensor_C is not bias + self.bias = False + if "stream" in kwargs.keys(): + self.stream = kwargs["stream"] + else: + self.stream = cuda.CUstream(0) + + self.operation = operation + self.ptr_workspace = workspace + + # number of split-k partitions + self.partitions = partitions + + if is_numpy_tensor(destination): + self.host_D = destination + self.destination_buffer = NumpyFrontend.argument(destination, True) + self.source_buffer = NumpyFrontend.argument(source, False) + self.ptr_destination = cuda.CUdeviceptr(self.destination_buffer.ptr) + self.ptr_source = cuda.CUdeviceptr(self.source_buffer.ptr) + elif is_torch_tensor(destination): + self.ptr_destination = TorchFrontend.argument(destination) + self.ptr_source = TorchFrontend.argument(source) + elif isinstance(destination, cuda.CUdeviceptr): + self.ptr_destination = destination + self.ptr_source = source + else: + raise TypeError("unknown Type") + + self.problem_size = MatrixCoord_(problem_size[0], problem_size[1]) + + self.partition_stride = ( + problem_size[0] * problem_size[1] * DataTypeSize[operation.C.element] // 8 + ) + + if "output_op" in kwargs.keys(): + self.output_op = kwargs["output_op"] + else: + self.output_op = self.operation.epilogue_type(1.0, 0.0) + + self.get_arguments() + + @staticmethod + def get_tensor_ref( + extent: "tuple[int]", + device_ptr: cuda.CUdeviceptr, + layout: LayoutType, + ): + if layout == LayoutType.RowMajor: + return TensorRef2D_(int(device_ptr), extent[1]) + else: + raise ValueError(f"Unknown layout type {layout}") + + def get_arguments(self): + ref_workspace = ReductionArguments.get_tensor_ref( + extent=[ + self.problem_size.row, + self.problem_size.column, + ], + 
device_ptr=self.ptr_workspace, + layout=LayoutType.RowMajor, + ) + if self.bias: + ref_source = ReductionArguments.get_tensor_ref( + extent=[0, 0], + device_ptr=self.ptr_source, + layout=LayoutType.RowMajor, + ) + else: + ref_source = ReductionArguments.get_tensor_ref( + extent=[ + self.problem_size.row, + self.problem_size.column, + ], + device_ptr=self.ptr_source, + layout=LayoutType.RowMajor, + ) + + ref_destination = ReductionArguments.get_tensor_ref( + extent=[ + self.problem_size.row, + self.problem_size.column, + ], + device_ptr=self.ptr_destination, + layout=LayoutType.RowMajor, + ) + + self.c_arguments = self.operation.argument_type( + self.problem_size, + self.partitions, + self.partition_stride, + ref_workspace, + ref_destination, + ref_source, + self.output_op, + ) + + params_ = self.operation.rt_module.get_args(ctypes.byref(self.c_arguments)) + self.host_workspace = bytearray(params_.contents) + + def sync(self): + (err,) = cudart.cudaDeviceSynchronize() + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {str(err)}") + + if hasattr(self, "host_D"): + (err,) = cuda.cuMemcpyDtoH( + self.host_D, + self.ptr_destination, + self.host_D.size * self.host_D.itemsize, + ) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError("CUDA Error %s" % str(err)) + + self.free() + + def free(self): + """ + Frees allocated device-side memory + """ + # Free any device memory allocated manually + if not cutlass_cppgen.use_rmm: + for attr in ["destination_buffer", "source_buffer"]: + if hasattr(self, attr): + buf = getattr(self, attr) + if isinstance(buf, DevicePtrWrapper): + err, = cudart.cudaFree(buf.ptr) + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError(f"cudaFree failed with error {err}") + del buf + + +class ReductionRT(ExecutableOperation): + """ + ReductionRT manages the CUTLASS runtime components for reduction + """ + + KernelTemplate = r""" +extern "C" +__global__ void 
+${operation_name}(${operation_name}${operation_suffix}::Params params) { + + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + + // Declare pointer to dynamic shared memory. + ${operation_name}${operation_suffix}::SharedStorage *shared_storage = + reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase); + + ${operation_name}${operation_suffix} op; + + op(params, *shared_storage); +} + """ + HostTemplate = r""" +extern "C" { + // Get the size of params in bytes + int ${operation_name}_get_param_size(){ + return sizeof(${operation_name}${operation_suffix}::Params); + } + + // Get the size of dynamic shared memory in bytes + int ${operation_name}_shared_memory_size() { + return int(sizeof(${operation_name}${operation_suffix}::SharedStorage)); + } + + // Get the params as byte array + char* ${operation_name}_get_params(${operation_name}${operation_suffix}::Params* params){ + char *bytes = ((char*)(params)); + char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)]; + for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++) + output[i] = bytes[i]; + + return output; + } +} + """ + + def __init__(self, operation: ReductionOperation): + super().__init__(operation) + + self.operation: ReductionOperation = operation + self.emitter = EmitReductionInstance("_type") + + self.elements_per_access = self.operation.count + ( + self.argument_type, + self.epilogue_type, + ) = get_reduction_params(operation.epilogue_functor) + self.argtype = [ctypes.POINTER(self.argument_type)] + + def emit(self): + return self.emitter.emit(self.operation) + + def plan(self, arguments: ReductionArguments): + block_shape = [ + self.operation.shape.column // self.elements_per_access, + self.operation.shape.row, + 1, + ] + grid_shape = [ + (arguments.problem_size.row + self.operation.shape.row - 1) + // self.operation.shape.row, + (arguments.problem_size.column + 
self.operation.shape.column - 1) + // self.operation.shape.column, + 1, + ] + return LaunchConfiguration( + grid_shape, + block_shape, + self.shared_memory_capacity, + ) + + def initialize(self): + (err,) = cuda.cuFuncSetAttribute( + self.kernel, + attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + value=self.shared_memory_capacity, + ) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error: {err}") + + +class ReductionOperation: + """ + CUTLASS reduction Operation + """ + + def __init__( + self, + shape: MatrixCoord, + C: TensorDescription, + element_accumulator, + element_workspace=None, + element_compute=None, + epilogue_functor=None, + count: int = 1, + partitions_per_stage: int = 4, + ) -> None: + self.shape = shape + self.epilogue_functor = epilogue_functor + self.element_accumulator = element_accumulator + + if element_workspace is None: + self.element_workspace = element_accumulator + else: + self.element_workspace = element_workspace + + if element_compute is None: + self.element_compute = element_accumulator + else: + self.element_compute = element_compute + + self.element_output = C.element + self.C: TensorDescription = C + + # Reduce op processing size + self.count: int = count + + # Number of partitions to reduce per stage + self.partitions_per_stage: int = partitions_per_stage + + self.rt_module: ReductionRT = ReductionRT(self) + self.argument_type = self.rt_module.argument_type + self.epilogue_type = self.rt_module.epilogue_type + + def extended_name(self): + extend_name = "${element_workspace}_${element_accumulator}_${element_compute}_${element_output}" + + return SubstituteTemplate( + extend_name, + { + "element_workspace": DataTypeNames[self.element_workspace], + "element_accumulator": DataTypeNames[self.element_accumulator], + "element_compute": DataTypeNames[self.element_compute], + "element_output": DataTypeNames[self.element_output], + }, + ) + + def configuration_name(self): + """The full 
procedural name indicates architecture, extended name, tile size""" + + configuration_name = "cutlass_reduce_split_k_${extended_name}_${threadblock}" + + threadblock = "%dx%d" % ( + self.shape.row, + self.shape.column, + ) + + return SubstituteTemplate( + configuration_name, + { + "extended_name": self.extended_name(), + "threadblock": threadblock, + }, + ) + + def procedural_name(self): + """The full procedural name indicates architeture, extended name, tile size""" + return self.configuration_name() + + def run(self, arguments: ReductionArguments) -> cuda.CUresult: + """ + Configure and launch the cuda kernel with input arguments + """ + launch_config = self.rt_module.plan(arguments) + + host_workspace = arguments.host_workspace + device_workspace = None + + err = self.rt_module.run( + host_workspace, + device_workspace, + launch_config, + arguments.stream + ) + + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {str(err)}") + + return err + + +class EmitReductionInstance: + def __init__(self, operation_suffix="") -> None: + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/numeric_types.h", + "cutlass/arch/arch.h", + "cutlass/arch/mma.h", + "cutlass/layout/matrix.h", + "cutlass/gemm/device/gemm.h", + "cutlass/gemm/device/gemm_universal_adapter.h", + "cutlass/gemm/kernel/default_gemm_universal.h", + "cutlass/reduction/kernel/reduce_split_k.h", + "cutlass/reduction/thread/reduction_operators.h", + ] + self.template = """ +// Reduction kernel instance +using ${operation_name}_base = +typename cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<${shape_row}, ${shape_column}>, + ${epilogue_functor}, + cutlass::reduction::thread::ReduceAdd< + ${element_accumulator}, + ${element_output}, + ${count}>, + ${partition_per_stage}>; + +struct ${operation_name}${operation_suffix}: + public ${operation_name}_base { }; + """ + + def emit(self, operation: ReductionOperation): + vector_length_bits = 
min(operation.C.alignment * DataTypeSize[operation.C.element], 128) + epilogue_vector_length = vector_length_bits // DataTypeSize[operation.C.element] + + values = { + "operation_name": operation.configuration_name(), + "operation_suffix": self.operation_suffix, + "shape_row": str(operation.shape.row), + "shape_column": str(operation.shape.column), + "epilogue_functor": operation.epilogue_functor.emit(), + "element_output": DataTypeTag[operation.element_output], + "epilogue_vector_length": str(epilogue_vector_length), + "element_accumulator": DataTypeTag[operation.element_accumulator], + "element_compute": DataTypeTag[operation.element_compute], + "element_workspace": DataTypeTag[operation.element_workspace], + "count": str(operation.count), + "partition_per_stage": str(operation.partitions_per_stage), + } + + return SubstituteTemplate(self.template, values) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/type_hint.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/type_hint.py new file mode 100644 index 0000000000000000000000000000000000000000..fffa03360f7e0eb2f3a2a20e5c8a4e04d009bee9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/type_hint.py @@ -0,0 +1,35 @@ +################################################################################ +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################ + +GemmOperation = "Union[GemmOperationUniversal, GemmOperationGrouped]" + +Tensor = "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]" diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0bae3bac1163c55a698dfc8722c62ac85cb25abf --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/__init__.py @@ -0,0 +1,33 @@ +################################################################################ +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +from cutlass_cppgen.backend.utils.device import check_cuda_errors, device_cc diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/device.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/device.py new file mode 100644 index 0000000000000000000000000000000000000000..9ed4096a6f4b772a58702c2f4b089cc32d707614 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/device.py @@ -0,0 +1,126 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utility functions for interacting with the device +""" +from __future__ import annotations + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +cudart = lazy_import("cuda.cudart") + +import cutlass_cppgen +from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor + + +def check_cuda_errors(result: list): + """ + Checks whether `result` contains a CUDA error and, if so, raises it as an exception. Otherwise, + returns the result contained in the remaining fields of `result`. + + :param result: the results of the `cudart` method, consisting of an error code and any method results + :type result: list + + :return: non-error-code results from the `result` parameter + """ + # `result` is of the format : (cudaError_t, result...) 
+ err = result[0] + if err.value: + raise RuntimeError("CUDA error: {}".format(cudart.cudaGetErrorName(err))) + + if len(result) == 1: + return None + elif len(result) == 2: + return result[1] + else: + return result[1:] + + +def device_cc(device: int = -1) -> int: + """ + Returns the compute capability of the device with ID `device`. + + :param device: ID of the device to query + :type device: int + + :return: compute capability of the queried device (e.g., 80 for SM80) + :rtype: int + """ + if device == -1: + device = cutlass_cppgen.device_id() + + deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device)) + major = str(deviceProp.major) + minor = str(deviceProp.minor) + return int(major + minor) + + +def device_sm_count(device: int = -1): + if device == -1: + device = cutlass_cppgen.device_id() + err, device_sm_count = cuda.cuDeviceGetAttribute( + cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device + ) + if err != cuda.CUresult.CUDA_SUCCESS: + raise Exception( + "Failed to retireve SM count. 
" + f"cuDeviceGetAttribute() failed with error: {cuda.cuGetErrorString(err)[1]}" + ) + + return device_sm_count + + +def to_device_ptr(tensor) -> cuda.CUdeviceptr: + """ + Converts a tensor to a CUdeviceptr + + :param tensor: tensor to convert + :type tensor: np.ndarray | torch.Tensor | cp.ndarray | int + + :return: device pointer + :rtype: cuda.CUdeviceptr + """ + if is_numpy_tensor(tensor): + ptr = cuda.CUdeviceptr(tensor.__array_interface__["data"][0]) + elif is_torch_tensor(tensor): + ptr = cuda.CUdeviceptr(tensor.data_ptr()) + elif is_cupy_tensor(tensor): + ptr = cuda.CUdeviceptr(int(tensor.data.ptr)) + elif isinstance(tensor, cuda.CUdeviceptr): + ptr = tensor + elif isinstance(tensor, int): + ptr = cuda.CUdeviceptr(tensor) + else: + raise NotImplementedError(tensor) + + return ptr diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8e4121b59e57e26e8a32022916089e0916db4988 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/__init__.py @@ -0,0 +1,33 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_cppgen.emit.pytorch import pytorch diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/common.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/common.py new file mode 100644 index 0000000000000000000000000000000000000000..58f94e15148f934c92318b586d63b669757ed5f0 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/common.py @@ -0,0 +1,267 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Common utilities for emitting CUTLASS kernels +""" + +import cutlass_cppgen + +# Strings used for printing information about the generation of emitted scripts +_AUTOGEN_STR = f"This file was automatically generated by the CUTLASS {cutlass_cppgen.__version__} Python interface (https://github.com/nvidia/cutlass/python)" + + +_CSTYLE_AUTOGEN_COMMENT = f"""// {_AUTOGEN_STR} +""" + + +_PYSTYLE_AUTOGEN_COMMENT = f"""# {_AUTOGEN_STR} +""" + +_CUTLASS_KERNEL_ARGS_2x = """ + typename DeviceKernel::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, // problem size + 1, + {alpha, beta}, + A, B, C, D, + 0, 0, 0, 0, // batch strides + DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda + DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb + DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc + DeviceKernel::LayoutC::packed({M, N}).stride(0) // ldd + }; +""" + +_CUTLASS_KERNEL_ARGS_2x_STREAM_K = """ + typename DeviceKernel::Arguments arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, // problem size + 1, + {alpha, beta}, + A, B, C, D, + 0, 0, 0, 0, // batch strides + DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda + DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb + DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc + DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldd + -1 // avail_sms + }; +""" + +_CUTLASS_KERNEL_RUN_GEMM_2x = """ +using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute; + +cutlass::Status ${name}_kernel_run(int M, int N, int K, + const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D, + ElementCompute alpha, ElementCompute beta) { + ${args} + size_t workspace_size = DeviceKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + 
DeviceKernel gemm_op; + cutlass::Status status = gemm_op.initialize(arguments, + workspace.get(), + nullptr); // CUDA stream + + if (status != cutlass::Status::kSuccess) { + return status; + } + + status = gemm_op(); + return status; +} +""" + +_CUTLASS_KERNEL_RUN_GEMM_3x = """ +using StrideA = typename DeviceKernel::GemmKernel::StrideA; +using StrideB = typename DeviceKernel::GemmKernel::StrideB; +using StrideC = typename DeviceKernel::GemmKernel::StrideC; +using StrideD = typename DeviceKernel::GemmKernel::StrideD; + +using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute; + +cutlass::Status ${name}_kernel_run( + int M, int N, int K, int L, + const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D, + ElementCompute alpha, ElementCompute beta, const cutlass::KernelHardwareInfo& hw_info) { + + typename DeviceKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K, L}, // problem size + { + A, // ptrA + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A + B, // ptrB + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B + }, + { + {alpha, beta}, + C, // ptrC + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C + D, // ptrD + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D + }, + hw_info + }; + + size_t workspace_size = DeviceKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + DeviceKernel gemm_op; + cutlass::Status status = gemm_op.run(arguments, + workspace.get(), + nullptr); // CUDA stream + + return status; +} +""" + + +_CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x = """ +using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute; + +int threadblock_count = DeviceKernel::sufficient(); + +cutlass::Status ${name}_kernel_run(int problem_count, 
cutlass::gemm::GemmCoord* problem_sizes, + DeviceKernel::ElementA** A, DeviceKernel::ElementB** B, DeviceKernel::ElementC** C, DeviceKernel::ElementC** D, + int64_t* lda, int64_t* ldb, int64_t* ldc, int64_t* ldd, + ElementCompute alpha, ElementCompute beta) { + + typename DeviceKernel::Arguments arguments { + problem_sizes, + problem_count, + threadblock_count, + {alpha, beta}, + A, B, C, D, + lda, ldb, ldc, ldd + }; + + size_t workspace_size = DeviceKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + DeviceKernel gemm_op; + cutlass::Status status = gemm_op.initialize(arguments, + workspace.get(), + nullptr); // CUDA stream + + if (status != cutlass::Status::kSuccess) { + return status; + } + + status = gemm_op(); + return status; +} +""" + + +_CUTLASS_KERNEL_RUN_CONV2D_2x = """ + +using UnderlyingKernel = typename DeviceKernel::UnderlyingKernel; +namespace { +using TensorRefA = typename UnderlyingKernel::TensorRefA; +using TensorRefB = typename UnderlyingKernel::TensorRefB; +using TensorRefC = typename UnderlyingKernel::TensorRefC; +using ElementCompute = typename UnderlyingKernel::EpilogueOutputOp::ElementCompute; +} + +template +TensorRef get_tensor_ref(cutlass::Tensor4DCoord tensor_coord, Element* ptr){ + cutlass::layout::TensorNHWC layout = cutlass::layout::TensorNHWC::packed(tensor_coord); + TensorRef tensor_ref(ptr, layout); + return tensor_ref; +} + +cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_size, + UnderlyingKernel::ElementA* A, UnderlyingKernel::ElementB* B, + UnderlyingKernel::ElementC* C, UnderlyingKernel::ElementC* D, + ElementCompute alpha, ElementCompute beta, std::string split_k_mode, + cudaStream_t stream, int device_id=0) { + // create the tensor references + cutlass::Tensor4DCoord tensor_coord_A = cutlass::conv::implicit_gemm_tensor_a_extent( + cutlass::conv::Operator::k${conv_kind_name}, *problem_size + ); + cutlass::Tensor4DCoord tensor_coord_B = 
cutlass::conv::implicit_gemm_tensor_b_extent( + cutlass::conv::Operator::k${conv_kind_name}, *problem_size + ); + cutlass::Tensor4DCoord tensor_coord_C = cutlass::conv::implicit_gemm_tensor_c_extent( + cutlass::conv::Operator::k${conv_kind_name}, *problem_size + ); + + TensorRefA tensor_ref_A = get_tensor_ref(tensor_coord_A, A); + TensorRefB tensor_ref_B = get_tensor_ref(tensor_coord_B, B); + TensorRefC tensor_ref_C = get_tensor_ref(tensor_coord_C, C); + TensorRefC tensor_ref_D = get_tensor_ref(tensor_coord_C, D); + + cutlass::conv::SplitKMode mode; + if (split_k_mode == "serial") { + mode = cutlass::conv::SplitKMode::kSerial; + } else if (split_k_mode == "parallel") { + mode = cutlass::conv::SplitKMode::kParallel; + } else { + throw std::runtime_error("Invalid split_k_mode: " + split_k_mode); + } + + typename DeviceKernel::Arguments arguments{ + *problem_size, + tensor_ref_A, + tensor_ref_B, + tensor_ref_C, + tensor_ref_D, + {alpha, beta}, + mode + }; + + DeviceKernel implicit_gemm_op; + + size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); + + void* workspace_ptr = device_memory_allocation(workspace_size, device_id); + + cutlass::Status status = implicit_gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return status; + } + + status = implicit_gemm_op.initialize(arguments, workspace_ptr, stream); + if (status != cutlass::Status::kSuccess) { + return status; + } + + // + // Launch initialized CUTLASS kernel + // + status = implicit_gemm_op(stream); + + return status; +} +""" diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/pytorch.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..fe96f3ede11163da01520f972eb97282a2ab2b14 --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/pytorch.py @@ -0,0 +1,936 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Utilities for generating source for building a PyTorch CUDA extension that using a CUTLASS kernel. +If specified, the extension can be JIT compiled via PyTorch's ``cpp_extension.load`` method. + +Example usage with JIT compilation: + +.. highlight:: python +.. code-block:: python + + plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_library.LayoutType.RowMajor) + op = plan.construct() + mod = cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=True) + + # Generate inputs for the GEMM + A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)] + + # Run the module + D = mod.run(A, B, C) + + +Example usage without JIT compilation: + +.. highlight:: python +.. code-block:: python + + plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor) + op = plan.construct() + cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output') + +After this call, the directory ``output`` contains ``setup.py``, +``cutlass_gemm.cpp``, and ``cutlass_gemm_kernel.cu``. The module can be built from +within ``output`` by running: ``TORCH_CUDA_ARCH_LIST="8.0" python setup.py develop --user``. + +The module can later be used in Python via: + +.. highlight:: python +.. 
code-block:: python + + import torch + import cutlass_gemm + + # Generate inputs for the GEMM + A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)] + + # Run the module + D = cutlass_gemm.run(A, B, C) +""" + +import logging +import os + +from cutlass_library import ConvKind, ConvKindNames, DataType, SubstituteTemplate + +from cutlass_cppgen import CUTLASS_PATH, logger, swizzle +from cutlass_cppgen.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal +from cutlass_cppgen.backend.conv2d_operation import Conv2dOperation +from cutlass_cppgen.backend.library import ApiVersion +from cutlass_cppgen.emit import common +from cutlass_cppgen.utils.datatypes import is_torch_available + +if is_torch_available(): + import torch + + +_PYTORCH_CUDA_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ +#include +#include +#include +#include +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" + +// helper function allocating the memory +void* device_memory_allocation(size_t size, int device_id=0) { + if (size > 0) { + torch::Device device(torch::kCUDA, device_id); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device); + at::Tensor device_tensor = torch::empty({(long)size,}, options); + return reinterpret_cast(device_tensor.data_ptr()); + } else { + return nullptr; + } +} + +${includes} +${declaration} +${impl} +""" + +_PYTORCH_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ +#include +#include +#include + +// CUDA forward declarations +at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, float alpha=1.f, float beta=0.f); + +// C++ interface +at::Tensor ${name}(const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, float alpha=1.f, float beta=0.f) { + return ${name}_kernel(A, B, C, alpha, beta); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("run", py::overload_cast, 
float, float>(&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f); +} +""" + +_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ +#include +#include +#include + +// CUDA forward declarations +std::vector ${name}_kernel(const std::vector& A, const std::vector& B, at::optional> C=at::nullopt, float alpha=1.f, float beta=0.f); + +// C++ interface +std::vector ${name}(const std::vector& A, const std::vector& B, at::optional> C=at::nullopt, float alpha=1.f, float beta=0.f) { + return ${name}_kernel(A, B, C, alpha, beta); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("run", py::overload_cast&, const std::vector&, at::optional>, float, float>(&${name}), + py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f); +} +""" + +_PYTORCH_CONV2D_FPROP_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ +#include +#include +#include + +// CUDA forward declarations +at::Tensor ${name}_kernel( + const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, + float alpha=1.f, float beta=0.f, + std::string split_k_mode="serial", int split_k_slices=1); + +// C++ interface +at::Tensor ${name}( + const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, + float alpha=1.f, float beta=0.f, + std::string split_k_mode="serial", int split_k_slices=1) { + return ${name}_kernel(A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("run", + py::overload_cast< + const at::Tensor&, const at::Tensor&, at::optional, + std::tuple, std::tuple, std::tuple, float, float, std::string, int>( + &${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, + py::arg("stride") = std::make_tuple(1, 1), 
py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1), + py::arg("alpha") = 1.f, py::arg("beta") = 0.f, + py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1); +} +""" + +_PYTORCH_CONV2D_GRAD_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ +#include +#include +#include + +// CUDA forward declarations +at::Tensor ${name}_kernel( + std::tuple result_size, const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, + float alpha=1.f, float beta=0.f, + std::string split_k_mode="serial", int split_k_slices=1); + +// C++ interface +at::Tensor ${name}( + std::tuple result_size, const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, + float alpha=1.f, float beta=0.f, + std::string split_k_mode="serial", int split_k_slices=1) { + return ${name}_kernel(result_size, A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("run", + py::overload_cast< + std::tuple, const at::Tensor&, const at::Tensor&, at::optional, + std::tuple, std::tuple, std::tuple, float, float, std::string, int>( + &${name}), py::arg("result_size"), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, + py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1), + py::arg("alpha") = 1.f, py::arg("beta") = 0.f, + py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1); +} +""" + +_PYTORCH_GEMM_INCLUDES = { + ApiVersion.v2x: """ +#include "cutlass/gemm/device/gemm_universal.h" +""", + ApiVersion.v3x: """ +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include 
"cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/util/packed_stride.hpp" +""", +} + +_PYTORCH_GROUPED_GEMM_INCLUDES = """ +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/device/gemm_grouped.h" +""" + +_PYTORCH_CONV2D_INCLUDES = """ +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/kernel/default_conv2d_dgrad.h" +#include "cutlass/conv/kernel/default_conv2d_wgrad.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" +""" + +_CUTLASS_TYPE_TO_TORCH_TYPE = { + DataType.f16: "torch::kF16", + DataType.f32: "torch::kF32", + DataType.f64: "torch::kF64", + DataType.s8: "torch::kI8", + DataType.s32: "torch::kI32", + DataType.bf16: "torch::kBFloat16", +} + +_PYTORCH_GEMM_IMPL_TEMPLATE_2x = ( + common._CUTLASS_KERNEL_RUN_GEMM_2x + + """ +at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional C, float alpha, float beta) { + int M = A.size(0); + int N = B.size(1); + int K = A.size(1); + + typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ? 
+ nullptr : + reinterpret_cast(C->contiguous().data_ptr()); + at::Tensor D = B.new_empty({M, N}, ${torch_type_C}); + + cutlass::Status status = ${name}_kernel_run(M, N, K, + reinterpret_cast(A.contiguous().data_ptr()), + reinterpret_cast(B.contiguous().data_ptr()), + ptrC, + reinterpret_cast(D.contiguous().data_ptr()), + ElementCompute(alpha), ElementCompute(beta)); + + TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed"); + return D; +} +""" +) + +_PYTORCH_GEMM_IMPL_TEMPLATE_3x = ( + common._CUTLASS_KERNEL_RUN_GEMM_3x + + """ +bool hw_info_queried = false; +cutlass::KernelHardwareInfo hw_info; + +at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional C, float alpha, float beta) { + int M = A.size(0); + int N = B.size(1); + int K = A.size(1); + int L = 1; + + // Query hardware info if we haven't already + if (!hw_info_queried) { + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ? 
+ nullptr : + reinterpret_cast(C->contiguous().data_ptr()); + at::Tensor D = B.new_empty({M, N}, ${torch_type_C}); + + cutlass::Status status = ${name}_kernel_run(M, N, K, L, + reinterpret_cast(A.contiguous().data_ptr()), + reinterpret_cast(B.contiguous().data_ptr()), + ptrC, + reinterpret_cast(D.contiguous().data_ptr()), + ElementCompute(alpha), ElementCompute(beta), + hw_info); + + TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed"); + return D; +} +""" +) + + +_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE = ( + common._CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x + + """ +std::vector ${name}_kernel(const std::vector& A, const std::vector& B, at::optional> C, float alpha, float beta) { + size_t num = A.size(); + + // To avoid performing many small cudaMallocs and host-to-device copies, + // we serialize the grouped GEMM arguments on the host, allocate one + // large chunk of device memory, and perform a single cudaMemcpy to + // copy the host data to the device. Allocation overheads could be + // avoided by using a memory pool. + + // Calculate the total size of the data to be copied from host to device + size_t total_size = sizeof(cutlass::gemm::GemmCoord) + + sizeof(DeviceKernel::ElementA*) + + sizeof(DeviceKernel::ElementB*) + + sizeof(DeviceKernel::ElementC*) + + sizeof(DeviceKernel::ElementC*) + + sizeof(int64_t) + + sizeof(int64_t) + + sizeof(int64_t); + total_size *= num; + + // num * sizeof(cutlass::gemm::GemmCoord) may leave one at a non-multiple + // of sizeof(DeviceKernel::ElementA*) (which will be 64 on a 64-bit system). + // To ensure that we don't end up having misaligned loads in the kernel, + // we pad to the nearest multiple of 8. 
+ // + // Note that, even on a 32-bit system (for which sizeof(X*) will not equal + // sizeof(int64_t)), only padding between the list of GemmCoords and the + // list of ptr_As is sufficient because the set of four equal-length lists of pointers + // (A*, B*, C*, D*) will ensure that the first list of int64_ts will always + // start on a multiple of 8. + int64_t padding = 8 - (total_size % 8); + total_size += padding; + + uint8_t* host_data = new uint8_t[total_size]; + cutlass::DeviceAllocation device_data(total_size); + + uint8_t* start = host_data; + cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast(start); + + // Apply the padding after the list of GemmCoords + start += num * sizeof(cutlass::gemm::GemmCoord) + padding; + + int64_t ptr_A_offset = start - host_data; + DeviceKernel::ElementA** ptr_A_host = reinterpret_cast(start); + start += num * sizeof(DeviceKernel::ElementA*); + + int64_t ptr_B_offset = start - host_data; + DeviceKernel::ElementB** ptr_B_host = reinterpret_cast(start); + start += num * sizeof(DeviceKernel::ElementB*); + + int64_t ptr_C_offset = start - host_data; + DeviceKernel::ElementC** ptr_C_host = reinterpret_cast(start); + start += num * sizeof(DeviceKernel::ElementC*); + + int64_t ptr_D_offset = start - host_data; + DeviceKernel::ElementC** ptr_D_host = reinterpret_cast(start); + start += num * sizeof(DeviceKernel::ElementC*); + + int64_t lda_offset = start - host_data; + int64_t* lda_host = reinterpret_cast(start); + start += num * sizeof(int64_t); + + int64_t ldb_offset = start - host_data; + int64_t* ldb_host = reinterpret_cast(start); + start += num * sizeof(int64_t); + + int64_t ldc_offset = start - host_data; + int64_t* ldc_host = reinterpret_cast(start); + start += num * sizeof(int64_t); + + std::vector D(num); + + bool need_C = (C != at::nullopt) && (beta != 0.f); + for (size_t i = 0; i < num; ++i) { + int M = A[i].size(0); + int N = B[i].size(1); + int K = A[i].size(1); + *(problem_sizes_host + i) = {M, N, K}; + 
*(ptr_A_host + i) = reinterpret_cast(A[i].contiguous().data_ptr()); + *(ptr_B_host + i) = reinterpret_cast(B[i].contiguous().data_ptr()); + + if (need_C) { + *(ptr_C_host + i) = reinterpret_cast(C->at(i).contiguous().data_ptr()); + } + else { + *(ptr_C_host + i) = nullptr; + } + + D[i] = B[i].new_empty({M, N}, ${torch_type_C}); + *(ptr_D_host + i) = reinterpret_cast(D[i].contiguous().data_ptr()); + + *(lda_host + i) = DeviceKernel::LayoutA::packed({M, K}).stride(0); + *(ldb_host + i) = DeviceKernel::LayoutB::packed({K, N}).stride(0); + *(ldc_host + i) = DeviceKernel::LayoutC::packed({M, N}).stride(0); + } + + device_data.copy_from_host(host_data); + + cutlass::Status status = ${name}_kernel_run( + num, + reinterpret_cast(device_data.get()), + reinterpret_cast(device_data.get() + ptr_A_offset), + reinterpret_cast(device_data.get() + ptr_B_offset), + reinterpret_cast(device_data.get() + ptr_C_offset), + reinterpret_cast(device_data.get() + ptr_D_offset), + reinterpret_cast(device_data.get() + lda_offset), + reinterpret_cast(device_data.get() + ldb_offset), + reinterpret_cast(device_data.get() + ldc_offset), + reinterpret_cast(device_data.get() + ldc_offset), + ElementCompute(alpha), ElementCompute(beta)); + + delete[] host_data; + + TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed"); + return D; +} +""" +) + +_PYTORCH_CONV2D_IMPL_TEMPLATE_2x = """ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + cutlass::Status status = ${name}_kernel_run( + &problem_size, + reinterpret_cast(A.data_ptr()), + reinterpret_cast(B.data_ptr()), + ptrC, + reinterpret_cast(D.data_ptr()), + alpha, beta, + split_k_mode, stream, B.device().index()); + + TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed"); + return D; +} +""" + +_PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x = ( + common._CUTLASS_KERNEL_RUN_CONV2D_2x + + """ +at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 
1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, + float alpha=1.f, float beta=0.f, std::string split_k_mode="serial", int split_k_slices=1) { + int N, H, W, C_, K, R, S, P, Q; + N = A.size(0); + C_ = A.size(1); + H = A.size(2); + W = A.size(3); + + K = B.size(0); + R = B.size(2); + S = B.size(3); + + cutlass::conv::Conv2dProblemSize problem_size( + cutlass::Tensor4DCoord(N, H, W, C_), + cutlass::Tensor4DCoord(K, R, S, C_), + cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)), + cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)), + cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)), + cutlass::conv::Mode::kCrossCorrelation, + split_k_slices + ); + + P = problem_size.P; + Q = problem_size.Q; + + typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ? + nullptr : + reinterpret_cast(C->data_ptr()); + + torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast); + at::Tensor D = torch::zeros({N, K, P, Q}, options); +""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x +) + + +_PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x = ( + common._CUTLASS_KERNEL_RUN_CONV2D_2x + + """ +at::Tensor ${name}_kernel(std::tuple input_size, const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, float alpha=1.f, float beta=0.f, + std::string split_k_mode="serial", int split_k_slices=1) { + int N, H, W, C_, K, R, S; + N = std::get<0>(input_size); + C_ = std::get<1>(input_size); + H = std::get<2>(input_size); + W = std::get<3>(input_size); + + K = B.size(0); + R = B.size(2); + S = B.size(3); + + cutlass::conv::Conv2dProblemSize problem_size( + cutlass::Tensor4DCoord(N, H, W, C_), + cutlass::Tensor4DCoord(K, R, S, C_), + cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)), 
+ cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)), + cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)), + cutlass::conv::Mode::kCrossCorrelation, + split_k_slices + ); + + typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ? + nullptr : + reinterpret_cast(C->data_ptr()); + + torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast); + at::Tensor D = torch::empty({N, C_, H, W}, options); +""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x +) + + +_PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x = ( + common._CUTLASS_KERNEL_RUN_CONV2D_2x + + """ +at::Tensor ${name}_kernel(std::tuple weight_size, const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, float alpha=1.f, float beta=0.f, + std::string split_k_mode="serial", int split_k_slices=1) { + int N, H, W, C_, K, R, S; + K = std::get<0>(weight_size); + C_ = std::get<1>(weight_size); + R = std::get<2>(weight_size); + S = std::get<3>(weight_size); + + N = B.size(0); + H = B.size(2); + W = B.size(3); + + cutlass::conv::Conv2dProblemSize problem_size( + cutlass::Tensor4DCoord(N, H, W, C_), + cutlass::Tensor4DCoord(K, R, S, C_), + cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)), + cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)), + cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)), + cutlass::conv::Mode::kCrossCorrelation, + split_k_slices + ); + + typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ? 
+ nullptr : + reinterpret_cast(C->data_ptr()); + + torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast); + at::Tensor D = torch::empty({K, C_, R, S}, options); +""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x +) + + +_PYTORCH_SETUP_PY = common._PYSTYLE_AUTOGEN_COMMENT + """ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +setup( + name='${name}', + ext_modules=[ + CUDAExtension('${name}', [ + '${name}.cpp', + '${name}_kernel.cu', + ], + include_dirs=['${cutlass_path}/include', '${cutlass_path}/tools/util/include'], + extra_compile_args={ + 'cxx': ['-std=c++17'], + 'nvcc': ['-std=c++17', ${extra_compile_args}], + }, + libraries=['cuda'] + ), + ], + cmdclass={ + 'build_ext': BuildExtension + }) + +""" + + +def _generate_setup(name: str, sourcedir: str, extra_compile_args: str=""): + """ + Generates a setup.py file for the extension + + :param name: name of the module to generate + :type name: str + :param sourcedir: directory to which generated source files should be written + :type sourcedir: str + :param extra_compile_args: additional arguments to pass to setup.py + :type extra_args: str + """ + setup_py_file = os.path.join(sourcedir, "setup.py") + setup_source = SubstituteTemplate( + _PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH, "extra_compile_args": extra_compile_args} + ) + with open(setup_py_file, "w") as outfile: + outfile.write(setup_source) + + +class _ArchListSetter: + """ + Utility context manager for temporarily setting the value of the ``TORCH_CUDA_ARCH_LIST`` + environment variable when building a PyTorch CUDA module. + + ``TORCH_CUDA_ARCH_LIST`` is a space-delmited list of compute capabilites for which a PyTorch + CUDA module should be compiled. 
+ + For example, ``TORCH_CUDA_ARCH_LIST="7.0 8.0"`` would result in the inclusion of + ``-gencode=arch=compute_70,code=sm_70`` and ``-gencode=arch=compute_80,code=sm_80`` in the + compilation of the module. + + This utility wraps the building of a PyTorch CUDA module with a setting of this environment + variable according to the current compute capability being targetted. + + Example usage: + + .. highlight:: python + .. code-block:: python + + # Temporarily set TORCH_CUDA_ARCH_LIST="8.0" + with _ArchListSetter(80): + # Perform JIT compilation and loading of the module + mod = torch.utils.cpp_extension.load(...) + + :param cc: compute capability + :type cc: int + """ + + _TORCH_CUDA_ARCH_LIST = "TORCH_CUDA_ARCH_LIST" + + def __init__(self, cc: int): + self.cc_str = ".".join(list(str(cc))) + + def __enter__(self): + """ + Saves the old value of TORCH_CUDA_ARCH_LIST and reset it to the new value based on ``cc`` + """ + self.old_arch_list = os.getenv(_ArchListSetter._TORCH_CUDA_ARCH_LIST) + os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.cc_str + + return self + + def __exit__(self, exc_type, exc_val, traceback): + """ + Restores the old value of TORCH_CUDA_ARCH_LIST + """ + if self.old_arch_list is None: + del os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] + else: + os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.old_arch_list + + +def _jit(name: str, cc: int, cpp_file: str, cuda_file: str): + """ + JIT compiles and loads a PyTorch CUDA extension. 
+ + :param name: name of the module to generate + :type name: str + :param cc: compute capability of the device the module should target + :type cc: int + :param cpp_file: path to file containing extension's C++ interface + :type cpp_file: str + :param cuda_file: path to file containing extension's CUDA interface + :type cuda_file: str + + :return: loaded PyTorch module + """ + + from torch.utils.cpp_extension import load + + extra_cuda_cflags = ["-std=c++17"] + if cc in [90, 100, 101, 103]: + # PyTorch does not currently add the sm_90a target when compute capability + # 9.0 is set within TORCH_CUDA_ARCH_LIST. Thus, we manually add the sm_90a target. + extra_cuda_cflags.append(f"-gencode=arch=compute_{cc}a,code=sm_{cc}a") + + with _ArchListSetter(cc): + jitmodule = load( + name, + [cpp_file, cuda_file], + extra_cuda_cflags=extra_cuda_cflags, + extra_include_paths=[ + os.path.join(CUTLASS_PATH, "include"), + os.path.join(CUTLASS_PATH, "tools/util/include"), + ], + extra_ldflags=["-lcuda"], + verbose=(logger.level == logging.DEBUG) + ) + return jitmodule + + +def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""): + """ + Generates source for building a PyTorch CUDA module that leverages the CUTLASS GEMM + specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time + compiled, loaded, and returned. 
+ + :param op: operation to emit in the module + :param name: name of the module to generate + :type name: str + :param cc: compute capability of the device the module should target + :type cc: int + :param jit: whether the module should be just-in-time compiled + :type jit: bool + :param sourcedir: directory to which generated source files should be written + :type sourcedir: str + + :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise + """ + if sourcedir != "" and not os.path.isdir(sourcedir): + os.makedirs(sourcedir) + + cuda_file = os.path.join(sourcedir, name + "_kernel.cu") + extra_kw = {} + if op.api == ApiVersion.v3x: + impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_3x + else: + impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_2x + if op.swizzling_functor == swizzle.ThreadblockSwizzleStreamK: + extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x_STREAM_K + else: + extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x + impl_template = ( + _PYTORCH_GEMM_IMPL_TEMPLATE_3x + if op.api == ApiVersion.v3x + else _PYTORCH_GEMM_IMPL_TEMPLATE_2x + ) + cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw}) + cuda_source = SubstituteTemplate( + _PYTORCH_CUDA_TEMPLATE, + { + "includes": _PYTORCH_GEMM_INCLUDES[op.api], + "declaration": op.rt_module.emit(), + "procedural_name": op.procedural_name(), + "impl": cuda_impl, + "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element], + }, + ) + with open(cuda_file, "w") as outfile: + outfile.write(cuda_source) + + cpp_file = os.path.join(sourcedir, name + ".cpp") + cpp_source = SubstituteTemplate( + _PYTORCH_GEMM_CPP_TEMPLATE, + {"name": name, "description": f"CUTLASS {op.procedural_name()} GEMM"}, + ) + with open(cpp_file, "w") as outfile: + outfile.write(cpp_source) + + extra_compile_args = "" + if cc in [90, 100, 101, 103]: + extra_compile_args = f"'--generate-code=arch=compute_{cc}a,code=[sm_{cc}a]'" + _generate_setup(name, sourcedir, extra_compile_args) + + if jit: + return _jit(name, cc, cpp_file, 
cuda_file) + + return None + + +def _pytorch_grouped_gemm( + op, name: str, cc: int, jit: bool = False, sourcedir: str = "" +): + """ + Generates source for building a PyTorch CUDA module that leverages the CUTLASS grouped GEMM + specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time + compiled, loaded, and returned. + + :param op: operation to emit in the module + :param name: name of the module to generate + :type name: str + :param cc: compute capability of the device the module should target + :type cc: int + :param jit: whether the module should be just-in-time compiled + :type jit: bool + :param sourcedir: directory to which generated source files should be written + :type sourcedir: str + + :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise + """ + if op.api != ApiVersion.v2x: + raise Exception("Grouped GEMM is currently only supported for CUTLASS 2.x") + + if sourcedir != "" and not os.path.isdir(sourcedir): + os.makedirs(sourcedir) + + cuda_file = os.path.join(sourcedir, name + "_kernel.cu") + cuda_impl = SubstituteTemplate(_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE, {"name": name}) + cuda_source = SubstituteTemplate( + _PYTORCH_CUDA_TEMPLATE, + { + "includes": _PYTORCH_GROUPED_GEMM_INCLUDES, + "declaration": op.rt_module.emit(), + "procedural_name": op.procedural_name(), + "impl": cuda_impl, + "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element], + }, + ) + with open(cuda_file, "w") as outfile: + outfile.write(cuda_source) + + cpp_file = os.path.join(sourcedir, name + ".cpp") + cpp_source = SubstituteTemplate( + _PYTORCH_GROUPED_GEMM_CPP_TEMPLATE, + {"name": name, "description": f"CUTLASS {op.procedural_name()} grouped GEMM"}, + ) + with open(cpp_file, "w") as outfile: + outfile.write(cpp_source) + + _generate_setup(name, sourcedir) + + if jit: + return _jit(name, cc, cpp_file, cuda_file) + + return None + + +def _pytorch_conv2d(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""): + """ + 
Generates source for building a PyTorch CUDA module that leverages the CUTLASS Conv2d + specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time + compiled, loaded, and returned. + + :param op: operation to emit in the module + :param name: name of the module to generate + :type name: str + :param cc: compute capability of the device the module should target + :type cc: int + :param jit: whether the module should be just-in-time compiled + :type jit: bool + :param sourcedir: directory to which generated source files should be written + :type sourcedir: str + + Note that the when conv kind is `dgrad` or `wgrad`, the size of the input `(N, C, H, W)` or + weight `(K, C, R, S)` should be provided. This is because there are multiple valid solutions + for H/W/R/S given the same P/Q. + + :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise + """ + if sourcedir != "" and not os.path.isdir(sourcedir): + os.makedirs(sourcedir) + cuda_file = os.path.join(sourcedir, name + "_kernel.cu") + extra_kw = {} + if op.conv_kind == ConvKind.Fprop: + impl_template = _PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x + cpp_template = _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE + elif op.conv_kind == ConvKind.Dgrad: + impl_template = _PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x + cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE + elif op.conv_kind == ConvKind.Wgrad: + impl_template = _PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x + cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE + extra_kw["conv_kind_name"] = ConvKindNames[op.conv_kind].capitalize() + extra_kw["torch_type_C"] = _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element] + cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw}) + cuda_source = SubstituteTemplate( + _PYTORCH_CUDA_TEMPLATE, + { + "includes": _PYTORCH_CONV2D_INCLUDES, + "declaration": op.rt_module.emit(), + "procedural_name": op.procedural_name(), + "impl": cuda_impl, + "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element], + }, + ) + with 
open(cuda_file, "w") as outfile: + outfile.write(cuda_source) + + cpp_file = os.path.join(sourcedir, name + ".cpp") + cpp_source = SubstituteTemplate( + cpp_template, + {"name": name, "description": f"CUTLASS {op.procedural_name()} Conv2d"}, + ) + with open(cpp_file, "w") as outfile: + outfile.write(cpp_source) + + _generate_setup(name, sourcedir) + + if jit: + return _jit(name, cc, cpp_file, cuda_file) + + return None + + +def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""): + """ + Generates source for building a PyTorch CUDA module that leverages the CUTLASS kernel + specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time + compiled, loaded, and returned. + + The result of this method is files within ``sourcedir`` that can be used for building + a PyTorch module. + + :param op: operation to emit in the module + :param name: name of the module to generate + :type name: str + :param cc: compute capability of the device the module should target + :type cc: int + :param jit: whether the module should be just-in-time compiled + :type jit: bool + :param sourcedir: directory to which generated source files should be written + :type sourcedir: str + + :return: loaded PyTorch module (if ``jit=True``) or None + """ + device_op = op.device_op() + if isinstance(op, GemmOperationUniversal): + return _pytorch_gemm(device_op, name, cc, jit, sourcedir) + elif isinstance(op, GemmOperationGrouped): + return _pytorch_grouped_gemm(device_op, name, cc, jit, sourcedir) + elif isinstance(op, Conv2dOperation): + return _pytorch_conv2d(device_op, name, cc, jit, sourcedir) + else: + raise Exception( + f"Operation type {type(op)} is not currently supported for PyTorch emission." 
+ ) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..faf6896e99ba78130ede8e09be9b9115e9169541 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/__init__.py @@ -0,0 +1,56 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_cppgen.epilogue.epilogue import ( + get_activations, + get_activation_epilogue, + gelu, + hardswish, + identity, + leaky_relu, + relu, + sigmoid, + silu, + tanh, + trace +) + +from cutlass_cppgen.epilogue.evt_ops import ( + max, + multiply_add, + sum, + permute, + reshape, + maximum, + minimum, + exp +) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/epilogue.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/epilogue.py new file mode 100644 index 0000000000000000000000000000000000000000..a3a17506ee2be609ed8d5b299114df52c55ca0cf --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/epilogue.py @@ -0,0 +1,176 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
"""
Registry of elementwise epilogues

Elementwise epilogues can be added to many CUTLASS kernels in the CUTLASS Python interface via
code like the following for GEMM:

.. highlight:: python
.. code-block:: python

    plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
    plan.activation = cutlass_cppgen.epilogue.relu
"""

from cutlass_cppgen.backend import epilogue, device_cc


# Module-level aliases so callers can write ``cutlass_cppgen.epilogue.relu`` etc.
gelu = epilogue.gelu
hardswish = epilogue.hardswish
identity = epilogue.identity
leaky_relu = epilogue.leaky_relu
relu = epilogue.relu
sigmoid = epilogue.sigmoid
silu = epilogue.silu
tanh = epilogue.tanh


# Activations accepted by ``get_activation_epilogue``
_activations = [gelu, hardswish, identity, leaky_relu, relu, sigmoid, silu, tanh]


def get_activations() -> list:
    """
    Returns a list of available activation functions

    :return: list of available activation functions
    :rtype: list
    """
    return _activations


def get_activation_epilogue(
    activation,
    element_output,
    elements_per_access,
    element_accumulator,
    element_compute,
):
    """
    Return an epilogue corresponding to the activation function, data types, and alignment
    used in the kernel

    :param activation: elementwise activation function to use
    :param element_output: data type of the output
    :param elements_per_access: alignment of operand C of the kernel
    :type elements_per_access: int
    :param element_accumulator: data type of the accumulated output C
    :param element_compute: data type in which compute operations should be performed

    :return: epilogue functor
    :raises Exception: if ``activation`` is not one of the registered activations
    """
    if activation not in _activations:
        raise Exception(
            f"Unsupported activation type {activation}. Available activations are: {_activations}"
        )

    # The identity activation maps onto the plain linear-combination epilogue;
    # every other activation goes through the generic, activation-parameterized one.
    if activation == identity:
        return epilogue.LinearCombination(
            element_output, elements_per_access, element_accumulator, element_compute
        )
    else:
        return epilogue.LinearCombinationGeneric(
            activation,
            element_output,
            elements_per_access,
            element_accumulator,
            element_compute,
        )


"""
Frontend for EVT that generates epilogue functor through tracing the input function
"""
from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend


def trace(fn, example_tensors, **kwargs):
    """
    Trace `fn(**example_tensors)` and generates epilogue visitor

    :param fn: Python callable or string containing the source of the epilogue function
    :type fn: callable or str
    :param example_tensors: example inputs for fn
    :type example_tensors: dict
    :param kwargs: forwarded to the ``PythonASTFrontend`` constructor; if ``cc`` is
        missing or falsy, ``device_cc()`` is used

    :return: the traced epilogue functor
    :raises NotImplementedError: if ``fn`` is neither callable nor a string

    .. highlight:: python
    .. code-block:: python

        import torch
        import cutlass_cppgen

        # Define epilogue function as Python callable
        def example_fn(accum, C, alpha, beta, gamma):
            D = ((accum + C) * alpha - gamma) / beta
            return D

        # Define the example tensors
        example_inputs = {
            "accum": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
            "C": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
            "alpha": 1.5,
            "beta": 0.5,
            "gamma": 2.5,
            "D": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda")
        }

        # Generate the epilogue functor
        epilogue_visitor = cutlass_cppgen.epilogue.trace(example_fn, example_inputs)
    """
    # The string branch needs these two stdlib modules; the module never imported
    # them, so passing a source string used to fail with NameError.
    import ast
    import textwrap

    if callable(fn):
        class EpilogueFunctor(PythonASTFrontend):
            def __init__(self, cc=None, **kwargs):
                if not cc:
                    cc = device_cc()
                super().__init__(cc, **kwargs)

        # Install the user function as the functor's ``__call__``
        setattr(EpilogueFunctor, "__call__", staticmethod(fn))

        epilogue_functor = EpilogueFunctor(**kwargs)
        epilogue_functor.trace(example_tensors)
        return epilogue_functor
    elif isinstance(fn, str):
        class EpilogueFunctor(PythonASTFrontend):
            def __init__(self, cc=None, **kwargs):
                # Dedent so that indented source snippets still parse
                self.source = textwrap.dedent(fn)
                if not cc:
                    cc = device_cc()
                super().__init__(cc, **kwargs)

            def parse(self, example_inputs) -> None:
                # Parse the stored source text and visit the resulting AST
                self.example_inputs = example_inputs
                self.ast = ast.parse(self.source)
                self.visit(self.ast)

        epilogue_functor = EpilogueFunctor(**kwargs)
        epilogue_functor.trace(example_tensors)
        return epilogue_functor
    else:
        raise NotImplementedError("Expect a callable Python function")
def multiply_add(x, y, z):
    """Return ``x * y + z``."""
    return x * y + z


def sum(x, dim):
    """
    Sum ``x`` over the dimensions in ``dim``.

    NumPy inputs use ``x.sum(axis=tuple(dim))`` and torch inputs use ``torch.sum``.
    Any other input type falls through and returns None.
    """
    if is_numpy_tensor(x):
        return x.sum(axis=tuple(dim))
    elif is_torch_tensor(x):
        return torch.sum(x, dim)


def max(x, dim):
    """
    Reduce ``x`` to its maximum over the dimensions in ``dim``.

    NumPy inputs use ``x.max(axis=tuple(dim))`` and torch inputs use ``torch.amax``.
    Any other input type falls through and returns None.
    """
    if is_numpy_tensor(x):
        return x.max(axis=tuple(dim))
    elif is_torch_tensor(x):
        return torch.amax(x, dim)


def maximum(x, y):
    """
    Elementwise maximum of ``x`` and ``y``.

    For torch inputs ``y`` is wrapped in a tensor only when it is not one already;
    re-wrapping an existing tensor with ``torch.tensor`` copies it and emits a warning.
    Any other type of ``x`` falls through and returns None.
    """
    if is_numpy_tensor(x):
        return np.maximum(x, y)
    elif is_torch_tensor(x):
        other = y if is_torch_tensor(y) else torch.tensor(y)
        return torch.maximum(x, other)


def minimum(x, y):
    """
    Elementwise minimum of ``x`` and ``y``.

    For torch inputs ``y`` is wrapped in a tensor only when it is not one already;
    re-wrapping an existing tensor with ``torch.tensor`` copies it and emits a warning.
    Any other type of ``x`` falls through and returns None.
    """
    if is_numpy_tensor(x):
        return np.minimum(x, y)
    elif is_torch_tensor(x):
        other = y if is_torch_tensor(y) else torch.tensor(y)
        return torch.minimum(x, other)


def exp(x):
    """
    Elementwise exponential of ``x`` for NumPy or torch inputs.

    Any other input type falls through and returns None.
    """
    if is_numpy_tensor(x):
        return np.exp(x)
    elif is_torch_tensor(x):
        return torch.exp(x)


##############################################################################
# Layout manipulate nodes
##############################################################################

def permute(x, indices: tuple):
    """
    Reorder the dimensions of ``x`` according to ``indices``.

    Any input that is neither a NumPy nor a torch tensor falls through and returns None.
    """
    if is_numpy_tensor(x):
        return np.transpose(x, axes=indices)
    elif is_torch_tensor(x):
        return x.permute(*indices)


def reshape(x, new_shape: tuple):
    """
    Return ``x`` reshaped to ``new_shape``.

    Any input that is neither a NumPy nor a torch tensor falls through and returns None.
    """
    if is_numpy_tensor(x):
        # Pass the shape positionally: the ``newshape`` keyword is deprecated in
        # recent NumPy releases, while the positional form works on every version.
        return np.reshape(x, new_shape)
    elif is_torch_tensor(x):
        # ``Tensor.view`` raises on non-contiguous inputs (for example the result of
        # ``permute`` above); ``Tensor.reshape`` returns a view when possible and
        # copies otherwise.
        return x.reshape(new_shape)
# Compute capabilities for which OptionRegistry builds an ArchOptions entry
_generator_ccs = [50, 60, 61, 70, 75, 80, 90, 100]


class KernelsForDataType:
    """
    Container class for keeping track of kernels that correspond to a particular combination
    of data types for operands A, B, and accumulator
    """

    def __init__(self, datatype_comb: tuple, layout_comb: tuple):
        self.datatype_comb = datatype_comb
        self.layout_comb = layout_comb
        self.math_operations = set()

        # Dictionary mapping from an alignment key, the string "<align A> <align B> <align C>",
        # to the list of kernels that have exactly those operand alignments
        self.kernels_by_alignment = {}

    def add(self, operation):
        """
        Add an operation to the list of supported kernels

        :param operation: kernel description exposing ``A``/``B``/``C`` with an ``alignment``
            attribute and ``tile_description.math_instruction.math_operation``
        """
        alignment_key = f"{operation.A.alignment} {operation.B.alignment} {operation.C.alignment}"
        self.kernels_by_alignment.setdefault(alignment_key, []).append(operation)
        self.math_operations.add(operation.tile_description.math_instruction.math_operation)

    def alignments(self, operand: str):
        """
        Returns an unsorted list of alignments supported by this data type combination

        :param operand: identifier of operand in question (e.g., A, B, C)
        :type operand: str

        :return: unsorted list of alignments supported by this data type combination
        :rtype: list
        """
        operand_idx = self._operand_idx(operand)
        return [int(key.split(" ")[operand_idx]) for key in self.kernels_by_alignment]

    @property
    def all_operations(self):
        """
        Returns a list of all operations supported by this data type combination

        :return: list of all operations supported by this data type combination
        :rtype: list
        """
        ops = []
        for alignment_ops in self.kernels_by_alignment.values():
            ops.extend(alignment_ops)
        return ops

    def default_operation(self, math_operation: "cutlass_cppgen.MathOperation"):
        """
        Returns the first kernel registered under the first alignment key

        :param math_operation: math operation to filter by, or None to accept any
        :type math_operation: cutlass_cppgen.MathOperation

        :return: first matching operation
        :raises IndexError: if no kernel with ``math_operation`` exists for the chosen alignment
        """
        # NOTE(review): keys are compared as strings, so "16 16 16" sorts before "8 8 8".
        # Kept as-is because it decides which kernel is the default -- confirm the intent.
        key = sorted(self.kernels_by_alignment.keys())[0]
        kernels = self.kernels_by_alignment[key]
        if math_operation is not None:
            kernels = [x for x in kernels if x.tile_description.math_instruction.math_operation == math_operation]
        return kernels[0]

    def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_operation: "cutlass_cppgen.MathOperation"):
        """
        Returns operations satisfying the alignment constraints

        :param alignment_A: alignment constraint of operations to return
        :type alignment_A: int
        :param alignment_B: alignment constraint of operations to return
        :type alignment_B: int
        :param alignment_C: alignment constraint of operations to return
        :type alignment_C: int
        :param math_operation: math operation to consider
        :type math_operation: cutlass_cppgen.MathOperation

        :return: list of operations
        :rtype: list
        :raises Exception: if no registered alignment combination is compatible
        """
        key = f"{alignment_A} {alignment_B} {alignment_C}"

        if key not in self.kernels_by_alignment:
            og_key = key
            # Reconcile A, B, and C alignments by trying to align to the minimum
            min_alignment = min(alignment_A, alignment_B, alignment_C)
            key = f"{min_alignment} {min_alignment} {min_alignment}"
            if key not in self.kernels_by_alignment:
                # Finally, go through all available alignment combinations, largest first,
                # and take the first one whose values all divide the requested alignments.
                key = None
                alignments = sorted([tuple(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True)
                for align_A, align_B, align_C in alignments:
                    if alignment_A % align_A == 0 and alignment_B % align_B == 0 and alignment_C % align_C == 0:
                        key = f"{align_A} {align_B} {align_C}"
                        break

                if key is None:
                    raise Exception(
                        f"No operations of alignment {og_key} found for data type and layout "
                        f"combination {self.datatype_comb} {self.layout_comb}. Compatible alignments "
                        f"are {self.kernels_by_alignment.keys()}"
                    )

        ops = self.kernels_by_alignment[key]
        if math_operation is not None:
            ops = [op for op in ops if op.tile_description.math_instruction.math_operation == math_operation]
        return ops

    def _operand_idx(self, key: str) -> int:
        """
        Returns the position of operand ``key`` within an alignment key

        :param key: operand identifier, one of "A", "B", "C"
        :type key: str

        :return: 0, 1, or 2
        :rtype: int
        :raises Exception: if ``key`` is not a known operand
        """
        operand_list = ["A", "B", "C"]
        if key not in operand_list:
            # The message previously formatted an undefined name, which raised NameError
            # instead of reporting the offending operand.
            raise Exception(f"Unexpected operand {key}")

        return operand_list.index(key)

    def find_alignment(self, shape: tuple, layout: "cutlass_cppgen.LayoutType", operand: str) -> int:
        """
        Returns the most preferable alignment for a given shape and layout

        :param shape: extent of each dimension of the tensor
        :type shape: tuple
        :param layout: layout of the tensor
        :type layout: cutlass_cppgen.LayoutType
        :param operand: descriptor of the operand in question
        :type operand: str

        :return: maximum alignment supported by the data type combination and tensor size
        :rtype: int
        """
        # ``operand`` used to be declared as ``operand=str``: a default value of the ``str``
        # type rather than an annotation, which could never resolve to a valid operand.
        operand_idx = self._operand_idx(operand)

        # Determine the leading dimension of the shape
        if layout == cutlass_cppgen.LayoutType.ColumnMajor:
            ld = shape[-2]
        elif layout == cutlass_cppgen.LayoutType.RowMajor:
            ld = shape[-1]
        elif layout == cutlass_cppgen.LayoutType.TensorNHWC:
            ld = shape[-1]
        else:
            raise Exception(f"Unexpected or unsupported layout {layout}")

        # Compare alignments numerically. Sorting the raw string keys placed "8 8 8" ahead
        # of "16 16 16", so the largest usable alignment was not always returned.
        candidates = sorted(
            {int(key.split(" ")[operand_idx]) for key in self.kernels_by_alignment},
            reverse=True,
        )
        for alignment in candidates:
            if ld % alignment == 0:
                return alignment

        # Default to alignment of 1 if no others match
        return 1

    def sort(self):
        """
        Sorts each list of kernels in `kernels_by_alignment` in descending order of threadblock shape
        """
        def threadblock_volume(op):
            shape = op.tile_description.threadblock_shape
            return shape[0] * shape[1] * shape[2]

        for kernels in self.kernels_by_alignment.values():
            kernels.sort(key=threadblock_volume, reverse=True)

    def supports_math_operation(self, math_operation: "cutlass_cppgen.MathOperation") -> bool:
        """
        Returns whether `math_operation` is supported by at least one operation.

        :param math_operation: math operation to consider
        :type math_operation: cutlass_cppgen.MathOperation

        :return: whether math_operation is supported by at least one operation
        :rtype: bool
        """
        return math_operation is None or math_operation in self.math_operations
class ArchOptions:
    """
    Structure for keeping track of kernels available on a given compute capability

    :param target_cc: compute capability of the device on which kernels will be run
    :type target_cc: int
    :param kernel_cc: compute capability of the kernels to generate
    :type kernel_cc: int
    :param operation_kind: type of operation to register
    :type operation_kind: cutlass_library.OperationKind
    :param gemm_kinds: types of GEMM operations that can be included
    :type gemm_kinds: list
    :param allowed_math_operations: types of primitive math operations allowed. When None,
        defaults to multiply_add, multiply_add_saturate, multiply_add_mixed_input_upcast
        and multiply_add_fast_f32
    :type allowed_math_operations: list
    """

    def __init__(
        self,
        target_cc: int,
        kernel_cc: int,
        operation_kind: cutlass_library.OperationKind,
        gemm_kinds: list,
        allowed_math_operations: list = None,
    ):
        # Build the default per instance. The previous list-literal default was a single
        # object stored on every instance, so mutating one instance's list changed them all.
        if allowed_math_operations is None:
            allowed_math_operations = [
                cutlass_library.MathOperation.multiply_add,
                cutlass_library.MathOperation.multiply_add_saturate,
                cutlass_library.MathOperation.multiply_add_mixed_input_upcast,
                cutlass_library.MathOperation.multiply_add_fast_f32
            ]

        self.cc = kernel_cc

        # Dictionary with following structure:
        #  Key: OpcodeClass
        #  Value: Dictionary with the following structure:
        #     Key: tuple of ((DataType, DataType, DataType), (LayoutType, LayoutType)),
        #          representing ((element_a, element_b, element_accumulator), (layout_a, layout_b))
        #     Value: KernelsForDataType
        self.operations_by_opclass = {}
        self.op_class = None
        self.allowed_math_operations = allowed_math_operations

        # No kernels are registered when SM90 kernels are paired with an SM100 target or
        # vice versa (parentheses only make the original and/or precedence explicit)
        if (target_cc == 100 and kernel_cc == 90) or (target_cc == 90 and kernel_cc == 100):
            return

        # Identify the method within CUTLASS generator script that generates kernel
        # descriptions for the target CC
        generate_function_name = "GenerateSM" + str(kernel_cc)
        if not hasattr(cutlass_library.generator, generate_function_name):
            cutlass_cppgen.logger.warning(f"No generator found for architecture {kernel_cc}")
            return
        generate_function = getattr(cutlass_library.generator, generate_function_name)

        # Initialize a default manifest and populate it with valid kernel descriptions
        # for the target CC
        args = [
            "--kernels=all",
            f"--log-level={logging.getLevelName(cutlass_cppgen.logger.level)}"
        ]
        manifest_args = cutlass_library.generator.define_parser().parse_args(args)
        manifest = cutlass_library.manifest.Manifest(manifest_args)
        generate_function(manifest, cutlass_cppgen._nvcc_version)

        if operation_kind not in manifest.operations:
            # No kernels generated for this architecture, this could be because the CUDA
            # toolkit is insufficient to support operations in this CC
            cutlass_cppgen.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}")
            return

        # Only one CC should be returned, given the setup above of calling only the generation scripts
        # for a given CC
        if len(manifest.operations[operation_kind].keys()) != 1 or kernel_cc not in manifest.operations[operation_kind]:
            raise Exception(f"Error finding kernels for SM{kernel_cc}. Check that your CUDA toolkit version "
                            "is sufficient for the architecture in question.")

        # Iterate through the available operations for this operation kind and
        # find available opclasses and data types
        for op_list in manifest.operations[operation_kind][kernel_cc].values():
            for op in op_list:

                if operation_kind == cutlass_library.OperationKind.Gemm:
                    if op.gemm_kind not in gemm_kinds:
                        continue

                mi = op.tile_description.math_instruction
                if mi.math_operation not in self.allowed_math_operations:
                    continue

                # Prune operations that don't fit in shared memory
                td = td_from_profiler_op(op)
                if not valid_stage_count(target_cc, kernel_cc, td, verbose=False)[0]:
                    continue

                if mi.opcode_class not in self.operations_by_opclass:
                    self.operations_by_opclass[mi.opcode_class] = {}

                datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator)
                layout_comb = (op.A.layout, op.B.layout)

                # Register TF32 kernels as F32 to enable F32 -> TF32 conversion + TF32 Tensor Core operations
                if datatype_comb == (cutlass_library.DataType.tf32, cutlass_library.DataType.tf32, cutlass_library.DataType.f32):
                    # TF32 kernels only supported on SM80 and beyond
                    if self.cc < 80:
                        continue
                    elif self.cc == 90 or self.cc == 100:
                        if (op.A.element != cutlass_library.DataType.f32
                            or op.B.element != cutlass_library.DataType.f32
                            or op.C.element != cutlass_library.DataType.f32):
                            continue

                    datatype_comb = (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32)

                opclass_dict = self.operations_by_opclass[mi.opcode_class]
                key = (datatype_comb, layout_comb)
                if key not in opclass_dict:
                    opclass_dict[key] = KernelsForDataType(datatype_comb, layout_comb)
                opclass_dict[key].add(op)

        # Set the default opclass to TensorOp, if available. Otherwise default to SIMT
        if cutlass_library.OpcodeClass.TensorOp in self.operations_by_opclass:
            self.op_class = cutlass_library.OpcodeClass.TensorOp
        else:
            self.op_class = cutlass_library.OpcodeClass.Simt

        # The profiler's generator may generate only a limited set of combinations of operands for SIMT kernels.
        # Here, we generate additional versions via a generic TileDescription.
        if cutlass_library.OpcodeClass.Simt not in self.operations_by_opclass:
            self.operations_by_opclass[cutlass_library.OpcodeClass.Simt] = {}

        if operation_kind == cutlass_library.OperationKind.Gemm:
            types = [
                (cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s8),
                (cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s32),
                (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16),
                (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32),
                (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32),
                (cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64),
            ]

            # Add FP8 A/B/C
            fp8_types = [cutlass_library.DataType.e4m3, cutlass_library.DataType.e5m2]
            for type_comb in combinations_with_replacement(fp8_types, 3):
                types.append(type_comb)

            # Add FP8 A/B with FP32 C
            for type_comb in combinations_with_replacement(fp8_types, 2):
                types.append(type_comb + (cutlass_cppgen.DataType.f32,))

            layouts = [
                (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.RowMajor),
                (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.ColumnMajor),
                (cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.RowMajor),
                (cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.ColumnMajor),
            ]
        elif operation_kind == cutlass_library.OperationKind.Conv2d:
            types = [
                (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16),
                (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32),
                (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32),
                (cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64),
            ]

            layouts = [
                (cutlass_library.LayoutType.TensorNHWC, cutlass_library.LayoutType.TensorNHWC),
            ]
        else:
            raise NotImplementedError(f"Operation kind {operation_kind} is currently unsupported.")

        alignment = 1
        epilogue_functor = cutlass_library.EpilogueFunctor.LinearCombination
        swizzling_functor = cutlass_library.SwizzlingFunctor.Identity8
        for type_comb in types:
            for layout_comb in layouts:
                comb = (type_comb, layout_comb)
                # Keep SIMT kernels that the generator already produced for this combination
                if comb in self.operations_by_opclass[cutlass_library.OpcodeClass.Simt]:
                    continue

                A = cutlass_library.TensorDescription(type_comb[0], layout_comb[0], alignment)
                B = cutlass_library.TensorDescription(type_comb[1], layout_comb[1], alignment)
                C = cutlass_library.TensorDescription(type_comb[2], cutlass_library.LayoutType.ColumnMajor, alignment)
                math_inst = cutlass_library.MathInstruction(
                    [1, 1, 1],
                    type_comb[0],
                    type_comb[1],
                    type_comb[2],
                    cutlass_library.OpcodeClass.Simt,
                    cutlass_library.MathOperation.multiply_add
                )

                td = cutlass_library.TileDescription(
                    [128, 128, 8], 2, [4, 2, 1], math_inst, 50, 1024)

                # Prune operations that don't fit in shared memory
                if not valid_stage_count(target_cc, kernel_cc, td_from_profiler_td(td), verbose=False)[0]:
                    continue

                new_kernels = KernelsForDataType(type_comb, layout_comb)

                if operation_kind == cutlass_library.OperationKind.Gemm:
                    new_operation = cutlass_library.manifest.GemmOperation(
                        cutlass_library.GemmKind.Universal, td.minimum_compute_capability,
                        td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor)
                    new_kernels.add(new_operation)
                elif operation_kind == cutlass_library.OperationKind.Conv2d:
                    for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
                        new_operation = cutlass_library.manifest.Conv2dOperation(
                            conv_kind, IteratorAlgorithm.Analytic, td.minimum_compute_capability, td,
                            A, B, C, type_comb[2], StrideSupport.Strided, epilogue_functor, swizzling_functor,
                            group_mode=GroupMode.SingleGroup
                        )
                        new_kernels.add(new_operation)

                self.operations_by_opclass[cutlass_library.OpcodeClass.Simt][comb] = new_kernels

        # Sort all operations
        for kernels_by_comb in self.operations_by_opclass.values():
            for kernels in kernels_by_comb.values():
                kernels.sort()

    def opclass_supports_combination(
        self, op_class: cutlass_library.OpcodeClass, datatype_comb: tuple, layout_comb: tuple, math_operation: cutlass_library.MathOperation
    ) -> bool:
        """
        Returns whether the provided operation class supports the provided data type and layout combination

        :param op_class: operation class to consider
        :type op_class: cutlass_library.OpcodeClass
        :param datatype_comb: tuple of data types for (element_A, element_B, element_accumulator)
        :type datatype_comb: tuple[cutlass_library.DataType]
        :param layout_comb: tuple of layouts for (layout_A, layout_B)
        :type layout_comb: tuple[cutlass_library.LayoutType]
        :param math_operation: math operation to consider or None if any can be considered
        :type math_operation: cutlass_cppgen.MathOperation

        :return: whether kernels are registered for the combination (and, when given, the math operation)
        :rtype: bool
        :raises Exception: if ``op_class`` has no entry in ``operations_by_opclass``
        """
        if op_class not in self.operations_by_opclass:
            raise Exception(f"Unexpected or unsupported operation class {op_class}")

        if operations := self.operations_by_opclass[op_class].get((datatype_comb, layout_comb)):
            if math_operation is not None:
                return operations.supports_math_operation(math_operation)
            else:
                return True

        return False


    def supporting_opclasses(
        self,
        element_a: cutlass_library.DataType,
        element_b: cutlass_library.DataType,
        element_accumulator: cutlass_library.DataType,
        layout_a: cutlass_library.LayoutType,
        layout_b: cutlass_library.LayoutType,
        math_operation: cutlass_library.MathOperation,
    ) -> set:
        """
        Returns a set of operation classes that support the provided data type combination

        :param element_a: data type of operand A
        :type element_a: cutlass_library.DataType
        :param element_b: data type of operand B
        :type element_b: cutlass_library.DataType
        :param element_accumulator: data type of accumulator
        :type element_accumulator: cutlass_library.DataType
        :param layout_a: layout of operand A
        :type layout_a: cutlass_library.LayoutType
        :param layout_b: layout of operand B
        :type layout_b: cutlass_library.LayoutType
        :param math_operation: math operation to consider
        :type math_operation: cutlass_cppgen.MathOperation

        :return: set of operation classes that support the provided data type combination
        :rtype: set
        """
        supporting_op_classes = set()
        datatype_comb = (element_a, element_b, element_accumulator)
        layout_comb = (layout_a, layout_b)

        for op_class in self.operations_by_opclass.keys():
            if self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation):
                supporting_op_classes.add(op_class)
        return supporting_op_classes

    def operations(
        self,
        op_class: cutlass_library.OpcodeClass,
        element_a: cutlass_library.DataType,
        element_b: cutlass_library.DataType,
        element_accumulator: cutlass_library.DataType,
        layout_a: cutlass_library.LayoutType,
        layout_b: cutlass_library.LayoutType,
        math_operation: cutlass_library.MathOperation,
    ) -> KernelsForDataType:
        """
        Returns the kernels registered for the provided operation class, data type and layout combination

        :param op_class: operation class to consider
        :type op_class: cutlass_library.OpcodeClass
        :param element_a: data type of operand A
        :type element_a: cutlass_library.DataType
        :param element_b: data type of operand B
        :type element_b: cutlass_library.DataType
        :param element_accumulator: data type of accumulator
        :type element_accumulator: cutlass_library.DataType
        :param layout_a: layout of operand A
        :type layout_a: cutlass_library.LayoutType
        :param layout_b: layout of operand B
        :type layout_b: cutlass_library.LayoutType
        :param math_operation: math operation to consider
        :type math_operation: cutlass_cppgen.MathOperation

        :return: container of kernels by alignment supported by the provided combination of parameters
        :rtype: KernelsForDataType
        :raises Exception: if the combination is not supported by ``op_class``
        """
        datatype_comb = (element_a, element_b, element_accumulator)
        layout_comb = (layout_a, layout_b)
        if not self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation):
            raise Exception(
                f"Data type layout combination {datatype_comb}, {layout_comb} "
                f"is not supported by opcode class {op_class} on CC {self.cc}."
            )
        return self.operations_by_opclass[op_class][(datatype_comb, layout_comb)]


class OptionRegistry:
    """
    Container of all architecture-specific options

    :param target_cc: compute capability of the device on which operations will be run
    :type target_cc: int
    """

    def __init__(self, target_cc: int):
        self.registry = {}

        if target_cc > 100 and (target_cc not in [101, 103, 120, 121]):
            raise Exception(f"Unsupported compute capability {target_cc}. The CUTLASS Python interface only supports compute capabilities up to the Blackwell architecture.")

        gemm_kinds = [cutlass_library.GemmKind.Universal, cutlass_library.GemmKind.Universal3x]
        operation_kinds = [cutlass_library.OperationKind.Gemm, cutlass_library.OperationKind.Conv2d]
        # Construct options for each CC
        for kernel_cc in _generator_ccs:
            self.registry[kernel_cc] = {}
            for opkind in operation_kinds:
                self.registry[kernel_cc][opkind] = ArchOptions(target_cc, kernel_cc, opkind, gemm_kinds)

    def options_for_cc(self, cc: int, op_kind=cutlass_library.OperationKind.Gemm) -> ArchOptions:
        """
        Returns the options registered for kernels of compute capability ``cc``

        :param cc: compute capability of the kernels
        :type cc: int
        :param op_kind: kind of operation whose options are requested
        :type op_kind: cutlass_library.OperationKind

        :return: options for the requested compute capability and operation kind
        :rtype: ArchOptions
        :raises KeyError: if nothing is registered for ``cc`` or for ``op_kind``
        """
        options_by_kind = self.registry.get(cc)
        if options_by_kind is None:
            # An unknown CC used to surface as "'NoneType' object is not subscriptable"
            raise KeyError(
                f"No options registered for compute capability {cc}. "
                f"Registered compute capabilities are {sorted(self.registry)}"
            )
        return options_by_kind[op_kind]
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad +from cutlass_cppgen.op.gemm import Gemm +from cutlass_cppgen.op.gemm_grouped import GroupedGemm +from cutlass_cppgen.op.op import OperationBase diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..711b27da13b54e30f8b25e839ffc4f51ed80dc5c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py @@ -0,0 +1,997 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" + Ease-of-use interface for constructing, compiling, and running CONVs + + The ``Conv2d`` interface is meant to allow one to easily instantiate, compile, and run + CONV2D operations in CUTLASS via Python, without specifying many configuration parameters. 
+ Under the hood, the interface will select sensible default parameters for the many template + parameters for CUTLASS CONVs. + + Note: optimal performance is not to be expected from this interface. To achieve optimal + performance, one should specify and tune each configuration parameter. + + The simplest example of using this interface is the following: + + .. highlight:: python + .. code-block:: python + + # A, B, C, and D are torch/numpy/cupy tensor objects + plan = cutlass_cppgen.op.Conv2d(kind="fprop", A=A, B=B, C=C, D=D) + plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + One can also use the interface by specifying data types of operands at construction + and using different tensor objects with these data types at runtime: + + .. highlight:: python + .. code-block:: python + + # The following is shorthand for: + # cutlass_cppgen.op.Conv2d(kind="fprop", + # element_A=torch.float32, element_B=torch.float32, + # element_C=torch.float32, element_D=torch.float32, + # element_accumulator=torch.float32) + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=torch.float32) + + A0 = torch.rand((128, 256), dtype=torch.float32, device='cuda') + B0 = torch.rand((256, 64), dtype=torch.float32, device='cuda') + C0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda') + D0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda') + plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + A1 = torch.rand((32, 128), dtype=torch.float32, device='cuda') + B1 = torch.rand((128, 256), dtype=torch.float32, device='cuda') + C1 = torch.zeros((32, 256), dtype=torch.float32, device='cuda') + D1 = torch.zeros((32, 256), dtype=torch.float32, device='cuda') + plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + The interface additionally enables one to decouple the compilation of the underlying CUTLASS + kernel from its execution: + + .. highlight:: python + .. 
code-block:: python + + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32) + + # Do other work... + + plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + # Do other work... + + plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + Elementwise activation functions are easily fused to the GEMM via the interface: + + .. highlight:: python + .. code-block:: python + + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32) + plan.activation = cutlass_cppgen.epilogue.relu + + Operations can also be run asynchronously: + + .. highlight:: python + .. code-block:: python + + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32) + args = plan.run() + + # Do other work... + + args.sync() +""" + +from __future__ import annotations +from typing import Optional +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +cudart = lazy_import("cuda.cudart") +from cutlass_library import ( + ConvKind, + ConvMode, + DataTypeSize, + IteratorAlgorithm, + OperationKind, + SplitKMode, + StrideSupport, +) + +import cutlass_cppgen +from cutlass_cppgen import epilogue +from cutlass_cppgen.backend import compiler +from cutlass_cppgen.backend.conv2d_operation import Conv2dArguments, Conv2dOperation +from cutlass_cppgen.backend.reduction_operation import ReductionOperation, ReductionArguments +from cutlass_cppgen.backend.library import TensorDescription, TileDescription +from cutlass_cppgen.op.op import OperationBase +from cutlass_cppgen.shape import Conv2DProblemSize, MatrixCoord +from cutlass_cppgen.utils import check, datatypes + + +class Conv2d(OperationBase): + """ + Constructs a ``Conv2d`` object. 
+ + The convolution kind (fprop, wgrad, dgrad), the data types of operands A, B, and C, + along with the data type of output D and that used for accumulation, are bound to the ``Conv2d`` + object throughout its lifetime -- these are not to be changed after a ``Conv2d`` has been constructed. + + The constructor has optional parameters for flexibly setting these parameters. The following + constructors are equivalent: + + .. highlight:: python + .. code-block:: python + + # Use F32 for A, B, C, D, and accumulation in fprop + + # Use the generic ``element`` parameter to concisely set all data types for operands to the same values. + Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f32) + + # Explicitly specify the data types to use for A, B, C, and D. + Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, + element_C=cutlass_cppgen.DataType.f32, element_D=cutlass_cppgen.DataType.f32) + + # Set the data types and elements from existing tensors. Note that one can use different tensors when + # executing the convolution via the ``run()`` method than passed in here (though those passed in to ``run()`` must + # have the same data type as those passed in here). + # A, B, C, and D are torch.Tensor objects of type torch.float32 under the channel-last layout + Conv2d(kind="fprop", A=A, B=B, C=C, D=D) + + # Explicitly specify the data type for only some of A, B, C, and D. 
Unspecified data types will inherit + # those passed in via the generic ``element`` + Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32, + element=cutlass_cppgen.DataType.f32) + + The order of precedence for the setting of the data type for a given operand/output is as follows: + 1) If the tensor type is specified (e.g., ``A``), use the data type inferred from this tensor + 2) Otherwise, if the data type (e.g., ``element_A``) is specified, use it + 3) Otherwise, use the generic values (e.g., ``element``) + + :param kind: the convolution kind (one of fprop, dgrad, or wgrad) + :type kind: str + :param A: tensor representing data type of operand A + :param B: tensor representing data type of operand B + :param C: tensor representing data type of operand C + :param D: tensor representing data type of operand D + :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type + :type element: cutlass_cppgen.DataType + :param element_A: data type to be used for operand A + :type element_A: cutlass_cppgen.DataType + :param element_B: data type to be used for operand B + :type element_B: cutlass_cppgen.DataType + :param element_C: data type to be used for operand C + :type element_C: cutlass_cppgen.DataType + :param element_D: data type to be used for operand D + :type element_D: cutlass_cppgen.DataType + :param element_accumulator: data type to be used in accumulation of the product of operands A and B + :type element_accumulator: cutlass_cppgen.DataType + :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90 + :type cc: int + :param kernel_cc: compute capability of kernels to generate. 
For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80 + :type kernel_cc: int + """ + def __init__( + self, kind="fprop", + A=None, B=None, C=None, D=None, alpha=1.0, beta=0.0, + element=None, + element_A=None, element_B=None, element_C=None, element_D=None, + element_accumulator=None, + cc: int = None, kernel_cc: int = None + ): + super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=OperationKind.Conv2d) + # Verify the kernel cc + if self.current_cc in [90, 100, 101, 103]: + # The Conv2d kernel on Hopper (SM90) is currently unsupported + # Revert to use SM80-tagged kernels + cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") + self.specified_kernel_cc = 80 + self._reset_options(80) + + # The arch is used in testing + self.arch = self.current_cc + self.name = "conv2d" + kind + + # The convolution kind. (concept: cutlass_library.library.ConvKind) + self.conv_kind = datatypes.getattr_enum(ConvKind, kind) + + # The element types (concept: cutlass library types) of A, B, C, and D + elements = [] + layouts = [] + + # Complete the data types based on user-provided arguments + for elt, tens, name in zip([element_A, element_B, element_C, element_D], + [A, B, C, D], + ["A", "B", "C", "D"]): + if elt is not None and tens is not None: + raise Exception(f'Must not specify both element_{name} and tensor {name}') + if elt is None and tens is None and element is None: + raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.') + + elt_to_set = None + lay_to_set = None + + if tens is not None: + elt_to_set, _ = datatypes.get_datatype_and_layout(tens) + else: + elt_to_set = elt if elt is not None else element + + assert elt_to_set is not None + + # Currently we only support layout TensorNHWC + lay_to_set = cutlass_cppgen.LayoutType.TensorNHWC + elements.append(datatypes.library_type(elt_to_set)) + layouts.append(lay_to_set) + + self._element_a, 
self._element_b, self._element_c, self._element_d = elements + self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts + + self.A, self.B, self.C, self.D, self.alpha, self.beta = A, B, C, D, alpha, beta + + if element_accumulator is None: + self._element_accumulator = self._element_c + else: + self._element_accumulator = datatypes.library_type(element_accumulator) + + # Default inputs if none is supplied in run() + self.A = A + self.B = B + self.C = C + self.D = D + + self.alpha = alpha + self.beta = beta + + # We only specify the stride of the swizzling functor here + # The actual swizzling functor is determined in run based on conv_kind and stride + self._swizzling_stride = 1 + + # Arguments that will be set to default value in _reset_operations + # The default tile_description and op_class are fetched from manifest of cutlass library + self._tile_description = None + self.op_class = None + # The default identity epilogue will be created + self.epilogue_functor = None + + self._reset_operations() + + # Arguments that will be determined online based on arguments of "run" + # based on stride, input/output channels, alignment, and conv_kind + self._iterator_algorithm = None + self._stride_support = None + + def _reset_operations(self, reset_epilogue: bool = True): + # Set the default op class + datatype_comb = (self._element_a, self._element_b, self._element_accumulator) + layout_comb = (self._layout_a, self._layout_b) + + self.possible_op_classes = self.options.supporting_opclasses( + self._element_a, self._element_b, self._element_accumulator, + self._layout_a, self._layout_b, self._math_operation + ) + + if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes: + self.opclass = cutlass_cppgen.OpcodeClass.TensorOp + elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes: + self.opclass = cutlass_cppgen.OpcodeClass.Simt + else: + if self._math_operation is not None: + math_op_str = f' and math operation {self._math_operation}' + 
else: + math_op_str = '' + + raise Exception(f'No kernel configuration found for supported data type and layout ' + f'combination {datatype_comb}x{layout_comb}{math_op_str}') + + if reset_epilogue: + self._reset_epilogue_functor_activation(epilogue.identity) + + self.alignment_pref_A = min( + 128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A"))) + self.alignment_pref_B = min( + 128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B"))) + self.alignment_pref_C = min( + 128 // DataTypeSize[self._element_c], max(self.possible_operations.alignments("C"))) + + # + # Tile description Related + # + + @property + def tile_description(self) -> TileDescription: + """ + Returns the tile description + """ + return self._tile_description + + @tile_description.setter + def tile_description( + self, td=None): + """ + Set the tile description + + :param td: tile description + :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys + { + "threadblock_shape": [int, int, int], + "warp_count": [int, int, int], + "stages": int, + "instruction_shape": [int, int, int] (optional), + "cluster_shape": [int, int, int] (optional) + } + """ + if td is None: + return + if isinstance(td, dict): + if self._tile_description is None: + op = self.possible_operations.default_operation(self._math_operation) + self._tile_description = datatypes.td_from_profiler_op(op) + if "cluster_shape" in td.keys(): + if td["cluster_shape"] != [1, 1, 1]: + cutlass_cppgen.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.") + td["cluster_shape"] = [1, 1, 1] + td = self._tile_description.clone_and_update(td) + + valid, msg = self._valid_tile_description(td) + if valid: + self._tile_description = td + else: + raise Exception(msg) + + def _valid_tile_description(self, td: TileDescription) -> tuple: + """ + Checks whether the provided tile description is valid for the given compute capability. 
At present, + this checks the following: + + - Does the tile description use a number of stages supported by the compute capability in question? + - Does the tile size requested fit within shared memory? + - Are cluster dimensions outside the valid range requested for a given architecture (e.g., + more non-unit cluster dimensions for pre-SM90 architectures)? + - Is the kernel schedule being used supported on the architecture in question? + + :param td: tile description to validate + :type td: cutlass_cppgen.backend.TileDescription + :return: tuple in which the first element is a bool indicating that the tile description is valid + and the second element is a string providing an optional error message. + :rtype: tuple + """ + valid, msg = check.valid_stage_count(self.cc, self.current_cc, td) + if not valid: + return (valid, msg) + + valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape) + if not valid: + return (valid, msg) + + return valid, msg + + def tile_descriptions(self) -> list: + """ + Returns a list of valid tile descriptions for the operations + + :returns: list of valid tile descriptions for the operations + :rtype: list + """ + descriptions = [] + description_str = [] + for op in self.possible_operations.all_operations: + td = datatypes.td_from_profiler_op(op) + + if self._math_operation is not None: + if td.math_instruction.math_operation != self._math_operation: + continue + + if str(td) not in description_str: + description_str.append(str(td)) + descriptions.append(td) + return descriptions + + # + # Swizzling functor Related + # + + @property + def swizzling_stride(self): + """ + Returns the stride of swizzling currently being used by the Conv2d + + :return: swizzing stride + """ + return self._swizzling_stride + + @swizzling_stride.setter + def swizzling_stride(self, stride: int): + """ + Sets the swizzling functor to the type specified by `swizzling_functor` + """ + if not isinstance(stride, int): + raise Exception(f"Expect 
integer (1, 2, 4, 8), got {stride}") + self._swizzling_stride = stride + + def _propose_swizzling_functor(self, stride): + """ + Automatically propose the swizzling functor based on the stride + """ + if self.conv_kind == ConvKind.Dgrad: + if stride[0] != 1 or stride[1] != 1: + return getattr(cutlass_cppgen.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}") + + return getattr(cutlass_cppgen.swizzle, f"IdentitySwizzle{self._swizzling_stride}") + + # + # Iterator Algorithm Related + # + + @property + def iterator_algorithm(self) -> IteratorAlgorithm: + """ + Returns the iterator algorithm + """ + return self._iterator_algorithm + + @iterator_algorithm.setter + def iterator_algorithm(self, alg: str): + """ + Sets the iterator algorithm + + :param alg: The iterator algorithm + :type td: string, options: "analytic", "optimized", "few_channels", and "fixed_channels" + """ + iterator_alg = datatypes.getattr_enum(IteratorAlgorithm, alg) + + # Check if the iterator algorithm is valid + if iterator_alg in [IteratorAlgorithm.FewChannels, IteratorAlgorithm.FixedChannels] and self.conv_kind != ConvKind.Fprop: + raise Exception(f"{self.conv_kind} does not support iterator algorithm {alg}.") + + self._iterator_algorithm = iterator_alg + + def _propose_iterator_algorithm(self, problem_size, alignment_a, alignment_b) -> IteratorAlgorithm: + """ + Propose a valid iterator algorithm based on problem size and alignment + """ + if self.conv_kind == ConvKind.Fprop: + # Check whether the fixed channel is applicable + if problem_size.C == alignment_a: + return IteratorAlgorithm.FixedChannels + elif (problem_size.C % alignment_a == 0 and + problem_size.R <= 32 and problem_size.S <= 32): + return IteratorAlgorithm.Optimized + else: + return IteratorAlgorithm.Analytic + elif self.conv_kind == ConvKind.Dgrad: + if (problem_size.K % alignment_a == 0 and + problem_size.R <= 32 and problem_size.S <= 32 and + problem_size.C % alignment_b == 0): + return IteratorAlgorithm.Optimized + 
else: + return IteratorAlgorithm.Analytic + elif self.conv_kind == ConvKind.Wgrad: + if (problem_size.K % alignment_a == 0 and + problem_size.C % alignment_b == 0): + return IteratorAlgorithm.Optimized + else: + return IteratorAlgorithm.Analytic + + def _validate_iterator_algorithm(self, iterator_algorithm, problem_size, alignment_a, alignment_b) -> bool: + """ + Validate whether the user provide iterator algorithm works for the given problem size + """ + if self.conv_kind == ConvKind.Fprop: + if iterator_algorithm == IteratorAlgorithm.FixedChannels: + return problem_size.C == alignment_a + elif iterator_algorithm == IteratorAlgorithm.Optimized: + return (problem_size.C % alignment_a == 0 and + problem_size.R <= 32 and problem_size.S <= 32) + elif iterator_algorithm == IteratorAlgorithm.FewChannels: + return problem_size.C % alignment_a == 0 + elif self.conv_kind == ConvKind.Dgrad: + if iterator_algorithm == IteratorAlgorithm.Optimized: + return (problem_size.K % alignment_a == 0 and + problem_size.R <= 32 and problem_size.S <= 32 and + problem_size.C % alignment_b == 0) + elif self.conv_kind == ConvKind.Wgrad: + if iterator_algorithm == IteratorAlgorithm.Optimized: + return (problem_size.K % alignment_a == 0 and + problem_size.C % alignment_b == 0) + + return True + + # + # Stride Support Related + # + + def _propose_stride_support(self, stride): + if self.conv_kind == ConvKind.Dgrad: + if stride[0] == 1 and stride[1] == 1: + return StrideSupport.Unity + + return StrideSupport.Strided + + # + # Construct and Compilation + # + + def construct( + self, tile_description: TileDescription = None, + alignment_A: int = None, alignment_B: int = None, alignment_C: int = None, + iterator_algorithm: IteratorAlgorithm = None, + stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None, + epilogue_functor=None) -> cutlass_cppgen.backend.Conv2dOperation: + """ + Constructs a ``cutlass_cppgen.backend.Conv2dOperation`` based on the input parameters and current + 
kernel specification of the ``Conv2d`` object. + + :param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass_cppgen.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + :param iterator_algorithm: the iterator algorithm used + :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm + :param stride_support: the stride support of dgrad + :type stride_support: cutlass_library.library.StrideSupport + :param swizzling_functor: the swizzling functor + :type swizzling_functor: cutlass_cppgen.swizzle + :param epilogue_functor: the epilogue functor + + :return: operation that was constructed + :rtype: cutlass_cppgen.backend.Conv2dOperation + """ + # Get alignment + alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A) + alignment_B = check.alignment_or_default(alignment_B, self.alignment_pref_B) + alignment_C = check.alignment_or_default(alignment_C, self.alignment_pref_C) + + tensor_A = TensorDescription(self._element_a, self._layout_b, alignment_A) + tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B) + tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C) + + if tile_description is None: + if self.tile_description is not None: + tile_description = self.tile_description + else: + op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0] + tile_description = datatypes.td_from_profiler_op(op) + else: + valid, err_str = self._valid_tile_description(tile_description) + if not valid: + raise Exception(f"Invalid tile description. 
{err_str}") + self.tile_description = tile_description + + if iterator_algorithm is None: + # If the iterator algorithm is already set + if self.iterator_algorithm is not None: + iterator_algorithm = self.iterator_algorithm + else: + # Otherwise, we conservatively use the analytic iterator for correctness + iterator_algorithm = IteratorAlgorithm.Analytic + + if stride_support is None: + # If the stride support is already set + if self._stride_support is not None: + stride_support = self._stride_support + else: + # Otherwise, we assume strided + stride_support = StrideSupport.Strided + + if swizzling_functor is None: + # If the swizzling functor is already set + swizzling_functor = self._propose_swizzling_functor(stride=(2, 2)) + + if epilogue_functor is None: + if self.epilogue_functor is not None: + epilogue_functor = self.epilogue_functor + else: + epilogue_functor = self._create_epilogue_functor_activation(self._activation) + + # Reset the alignment of the epilogue functor + epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, epilogue_functor) + + operation = Conv2dOperation( + conv_kind=self.conv_kind, + iterator_algorithm=iterator_algorithm, + arch=self.current_cc, + tile_description=tile_description, + A=tensor_A, B=tensor_B, C=tensor_C, + stride_support=stride_support, + epilogue_functor=epilogue_functor, + swizzling_functor=swizzling_functor, + ) + + return operation + + def compile(self, tile_description: TileDescription = None, + alignment_A: int = None, alignment_B: int = None, alignment_C: int = None, + iterator_algorithm: IteratorAlgorithm = None, + stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None, + epilogue_functor = None, print_module: bool = False) -> cutlass_cppgen.backend.Conv2dOperation: + """ + Emits and compiles the kernel currently specified. If ``tile_description`` and any + of the ``alignment`` parameters are set, the kernel will be chosen using this + tile description and alignments. 
Otherwise, a default tile description and alignment + will be used. + + ::param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass_cppgen.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + :param iterator_algorithm: the iterator algorithm used + :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm + :param stride_support: the stride support of dgrad + :type stride_support: cutlass_library.library.StrideSupport + :param swizzling_functor: the swizzling functor + :type swizzling_functor: cutlass_cppgen.swizzle + :param epilogue_functor: the epilogue functor + + :return: operation that was compiled + :rtype: cutlass_cppgen.backend.Conv2dOperation + """ + + self.operation = self.construct( + tile_description, alignment_A, alignment_B, alignment_C, + iterator_algorithm, stride_support, swizzling_functor, epilogue_functor) + + if print_module: + print(self.operation.rt_module.emit()) + + compiler.add_module([self.operation,]) + return self.operation + + # + # Run Related + # + + def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name): + """ + Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception + is raised if it does not. + + :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type tensor: numpy/cupy/torch array/tensor object + :param ref_dtype: data type for the tensor that this object was initialized to + :param name: identifier of the tensor to verify. 
Used in raising exceptions + :type name: str + """ + dtype, _ = datatypes.get_datatype_and_layout(tensor) + if dtype != ref_type: + raise Exception(f'Tensor {name} with type and layout {dtype} ' + f'does not match the expected type of {ref_type}.') + + def _get_and_verify_conv_problem_size(self, A, B, C, stride, padding, dilation): + if self.conv_kind == ConvKind.Fprop: + input = A + weight = B + output = C + output_tensor = "C" + elif self.conv_kind == ConvKind.Dgrad: + output = A + weight = B + input = C + output_tensor = "A" + elif self.conv_kind == ConvKind.Wgrad: + output = A + input = B + weight = C + output_tensor = "A" + else: + raise Exception(f"Convolution kind {self.conv_kind} is not supported") + + N_, H_, W_, C_ = datatypes.get_tensor_shape(input, op="CONV") + K_, R_, S_, _ = datatypes.get_tensor_shape(weight, op="CONV") + _, P_, Q_, _ = datatypes.get_tensor_shape(output, op="CONV") + + problem_size = Conv2DProblemSize( + N_, H_, W_, C_, + K_, R_, S_, C_, + padding[0], padding[1], + stride[0], stride[1], + dilation[0], dilation[1], + ConvMode.CrossCorrelation, + 1, 1 + ) + + if P_ != problem_size.P or Q_ != problem_size.Q: + raise Exception( + f"Tensor {output_tensor} size should be ({N_}, {problem_size.P}, {problem_size.Q}, {K_}), got ({N_}, {P_}, {Q_}, {K_})") + + return problem_size + + def run(self, A=None, B=None, C=None, D=None, + stride=(1, 1), padding=(0, 0), dilation=(1, 1), + alpha=None, beta=None, + split_k=("serial", 1), sync: bool = True, + print_module: bool = False, + stream: Optional[cuda.CUstream] = None) -> Conv2dArguments: + """ + Runs the kernel currently specified. If it has not already been, the kernel is emitted and + compiled. Tensors holding operands and outputs of the kernel are sourced either from the + ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta`` + parameters provided in the call, or from those + passed in on the construction of this object -- one of the two must be specified. 
+ + By default, this call returns only once the kernel has completed. To launch the kernel + and immediately return, set ``sync=False``. In this case, it is the responsibility of the + caller to syncrhonize the results of the kernel before attempting to access outputs + by calling ``sync()`` on the arguments returned from this call. + + :param A: tensor representing data type and layout of operand A + :param B: tensor representing data type and layout of operand B + :param C: tensor representing data type and layout of operand C + :param D: tensor representing data type and layout of operand D + :param stride: (stride_h, stride_w) describing the convolution stride. Default: (1, 1) + :param padding: (pad_h, pad_w) describing the convolution padding. Default: (0, 0) + :param dilation: (dilation_h, dilation_w) describing the dilation of convolution. Default: (1, 1) + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param split_k: a tuple (split_k_mode, split_k_slices) + :param sync: whether the call should wait for the kernel to complete before returning + :type sync: bool + :param print_module: whether to print the emitted C++ code + :type print_module: bool + :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) + :type stream: :class:`cuda.cuda.CUstream` + + :return: arguments passed in to the kernel + :rtype: cutlass_cppgen.backend.Conv2dArguments + """ + if not stream: + stream = cuda.CUstream(0) + super().run_setup() + + A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A") + B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B") + C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C") + D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D") + alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha") + beta = self._verify_scalar(beta, 
self.beta, self._element_c, "beta") + + # handle the case when there is no C + if C is None: + if beta != 0: + raise Exception(f"With beta {beta} != 0, C has to be provided.") + else: + C = D + + # Construct problem size based on input + # It also verifies whether the A, B, C, D, stride, padding, and dilation are matching + problem_size = self._get_and_verify_conv_problem_size(A, B, C, stride, padding, dilation) + + # Propose stride support based on input + stride_support = self._propose_stride_support(stride) + + # Propose swizzling functor + swizzling_functor = self._propose_swizzling_functor(stride) + + shape_a = datatypes.get_tensor_shape(A, op="CONV") + shape_b = datatypes.get_tensor_shape(B, op="CONV") + shape_c = datatypes.get_tensor_shape(C, op="CONV") + + # Get the alignment + alignment_a = self.possible_operations.find_alignment(shape_a, self._layout_a, operand="A") + alignment_b = self.possible_operations.find_alignment(shape_b, self._layout_b, operand="B") + alignment_c = self.possible_operations.find_alignment(shape_c, self._layout_c, operand="C") + + alignment_a = check.update_alignment(alignment_a, self.alignment_pref_A) + alignment_b = check.update_alignment(alignment_b, self.alignment_pref_B) + alignment_c = check.update_alignment(alignment_c, self.alignment_pref_C) + + # Propose iterator algorithm based on input + if self._iterator_algorithm is None: + # Propose a default iterator algorithm based on the problem size + iterator_algorithm = self._propose_iterator_algorithm(problem_size, alignment_a, alignment_b) + else: + if (self._validate_iterator_algorithm(self._iterator_algorithm, problem_size, alignment_a, alignment_b)): + iterator_algorithm = self._iterator_algorithm + else: + raise Exception(f"Iterator algorithm {self._iterator_algorithm} is invalid for current problem.") + + epilogue_args = [alpha, beta] + + if hasattr(self, "_activation_args"): + if isinstance(self._activation_args, list): + epilogue_args += self._activation_args + else: + 
epilogue_args.append(self._activation_args) + + if split_k[0] == "parallel" and split_k[1] > 1: + epilogue_functor = self._create_epilogue_functor_activation(epilogue.identity) + else: + epilogue_functor = self.epilogue_functor + + # The alignment is determined by the iterator function (I believe) + self.compile(tile_description=self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b, + alignment_C=alignment_c, iterator_algorithm=iterator_algorithm, stride_support=stride_support, + swizzling_functor=swizzling_functor, epilogue_functor=epilogue_functor, print_module=print_module) + + # Create reduction operation for parallel split-k + if split_k[0] == "parallel" and split_k[1] > 1: + epilogue_functor_reduction = self._reset_epilogue_functor_alignment(alignment_c, self.epilogue_functor) + self.reduction_operation = ReductionOperation( + shape=MatrixCoord(4, 32 * alignment_c), C=self.operation.C, + element_accumulator=self._element_accumulator, + element_compute=self._element_accumulator, + epilogue_functor=epilogue_functor_reduction, + count=alignment_c + ) + if print_module: + print(self.reduction_operation.rt_module.emit()) + compiler.add_module([self.reduction_operation,]) + + arguments = Conv2dArguments( + operation=self.operation, problem_size=problem_size, + A=A, B=B, C=C, D=D, + output_op=self.operation.epilogue_type(*epilogue_args), + split_k_mode=datatypes.getattr_enum(SplitKMode, split_k[0]), + split_k_slices=split_k[1], + stream=stream + ) + + self.operation.run(arguments) + + if split_k[0] == "parallel" and split_k[1] > 1: + implicit_gemm_size = arguments.problem_size.implicit_gemm_size(self.conv_kind) + reduction_arguments = ReductionArguments( + self.reduction_operation, + problem_size=[implicit_gemm_size.m, implicit_gemm_size.n], + partitions=split_k[1], + workspace=arguments.ptr_D, + destination=D, + source=C, + output_op=self.reduction_operation.epilogue_type(*epilogue_args), + stream=stream + ) + 
self.reduction_operation.run(reduction_arguments) + + if sync: + if split_k[0] == "parallel" and split_k[1] > 1: + reduction_arguments.sync() + + # Free memory allocated by args because we are not + # calling `arguments.sync()` in this case (which will free memory) + arguments.free() + else: + arguments.sync() + + return arguments + + # + # Helper functions + # + @staticmethod + def output_size(input_size, weight_size, padding, stride, dilation): + problem_size = Conv2DProblemSize( + *input_size, + *weight_size, + padding[0], padding[1], + stride[0], stride[1], + dilation[0], dilation[1], + ConvMode.CrossCorrelation, + 1, 1 + ) + return (problem_size.N, problem_size.P, problem_size.Q, problem_size.K) + + +# +# Easy to use interfaces for fprop, wgrad, and dgrad +# + +class Conv2dFprop(Conv2d): + def __init__( + self, + input=None, weight=None, C=None, output=None, alpha=1, beta=0, + element=None, + element_input=None, element_weight=None, element_C=None, element_output=None, + element_accumulator=None, + cc: int = None, kernel_cc: int = None): + A, B, D = input, weight, output + element_A, element_B, element_D = element_input, element_weight, element_output + super().__init__( + "fprop", A, B, C, D, alpha, beta, element, + element_A, element_B, element_C, element_D, + element_accumulator, cc, kernel_cc) + + def run( + self, input=None, weight=None, C=None, output=None, alpha=None, beta=None, + stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1), + sync: bool = True, print_module: bool = False, + stream: Optional[cuda.CUstream] = None) -> Conv2dArguments: + + if not stream: + stream = cuda.CUstream(0) + + A, B, D = input, weight, output + return super().run( + A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream) + + +class Conv2dDgrad(Conv2d): + def __init__( + self, + grad_output=None, weight=None, C=None, grad_input=None, alpha=1, beta=0, + element=None, + element_grad_output=None, element_weight=None, 
element_C=None, element_grad_input=None, + element_accumulator=None, + cc: int = None, kernel_cc: int = None): + A, B, D = grad_output, weight, grad_input + element_A, element_B, element_D = element_grad_output, element_weight, element_grad_input + super().__init__( + "dgrad", A, B, C, D, alpha, beta, element, + element_A, element_B, element_C, element_D, + element_accumulator, cc, kernel_cc) + + def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None, + stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1), + sync: bool = True, print_module: bool = False, + stream: Optional[cuda.CUstream] = None) -> Conv2dArguments: + # + if not stream: + stream = cuda.CUstream(0) + + A, B, D = grad_output, weight, grad_input + return super().run( + A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream) + + +class Conv2dWgrad(Conv2d): + def __init__( + self, + grad_output=None, input=None, C=None, grad_weight=None, alpha=1, beta=0, + element=None, + element_grad_output=None, element_input=None, element_C=None, element_grad_weight=None, + element_accumulator=None, + cc: int = None, kernel_cc: int = None): + A, B, D = grad_output, input, grad_weight + element_A, element_B, element_D = element_grad_output, element_input, element_grad_weight + super().__init__( + "wgrad", A, B, C, D, alpha, beta, element, + element_A, element_B, element_C, element_D, + element_accumulator, cc, kernel_cc) + + def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None, + stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1), + sync: bool = True, print_module: bool = False, + stream: Optional[cuda.CUstream] = None) -> Conv2dArguments: + if not stream: + stream = cuda.CUstream(0) + + A, B, D = grad_output, input, grad_weight + return super().run( + A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream) diff --git 
a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..a6f9b1ab43a1c45d0024e99e50e45813ba18866e --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py @@ -0,0 +1,725 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" + Ease-of-use interface for constructing, compiling, and running GEMMs. + + The ``Gemm`` interface is meant to allow one to easily instantiate, compile, and run + GEMM operations in CUTLASS via Python, without specifying many configuration parameters. + Under the hood, the interface will select sensible default parameters for the many template + parameters for CUTLASS GEMMs. + + Note: optimal performance is not to be expected from this interface. To achieve optimal + performance, one should specify and tune each configuration parameter. + + The simplest example of using this interface is the following: + + .. highlight:: python + .. code-block:: python + + # A, B, C, and D are torch/numpy/cupy tensor objects + plan = cutlass_cppgen.op.Gemm(A, B, C, D) + plan.run() + + + One can also use the interface by specifying data types of operands at construction + and using different tensor objects with these data types at runtime: + + .. highlight:: python + .. 
code-block:: python + + # The following is shorthand for: + # cutlass_cppgen.op.Gemm(element_A=torch.float32, element_B=torch.float32, + # element_C=torch.float32, element_D=torch.float32, + # element_accumulator=torch.float32, + # layout=cutlass_cppgen.LayoutType.RowMajor) + plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor) + + A0 = torch.rand((128, 256), device='cuda') + B0 = torch.rand((256, 64), device='cuda') + C0 = torch.zeros((128, 64), device='cuda') + D0 = torch.zeros((128, 64), device='cuda') + plan.run(A0, B0, C0, D0) + + A1 = torch.rand((32, 128), device='cuda') + B1 = torch.rand((128, 256), device='cuda') + C1 = torch.zeros((32, 256), device='cuda') + D1 = torch.zeros((32, 256), device='cuda') + plan.run(A1, B1, C1, D1) + + The interface additionally enables one to decouple the compilation of the underlying CUTLASS + kernel from its execution: + + .. highlight:: python + .. code-block:: python + + plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor) + plan.compile() + + # Do other work... + + plan.run(A0, B0, C0, D0) + + # Do other work... + + plan.run(A1, B1, C1, D1) + + Elementwise activation functions are easily fused to the GEMM via the interface: + + .. highlight:: python + .. code-block:: python + + plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor) + plan.activation = cutlass_cppgen.epilogue.relu + + Operations can also be run asynchronously: + + .. highlight:: python + .. code-block:: python + + plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor) + args = plan.run(sync=False) + + # Do other work... 
+ + args.sync() +""" +from __future__ import annotations +from typing import Optional +from math import prod + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +from cutlass_library import ( + DataType, + DataTypeSize, + GemmUniversalMode, + KernelScheduleSuffixes, +) + +import cutlass_cppgen +from cutlass_cppgen import epilogue, swizzle +from cutlass_cppgen.backend import compiler +from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor +from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal +from cutlass_cppgen.backend.library import TensorDescription, TileDescription +from cutlass_cppgen.op.op import OperationBase +from cutlass_cppgen.shape import GemmCoord +from cutlass_cppgen.utils import check, datatypes + + +class Gemm(OperationBase): + """ + Constructs a ``Gemm`` object. + + The data types and layouts of operands A, B, and C, along with the data type of output D + and that used for accumulation, are bound to the ``Gemm`` object throughout its lifetime -- + these are not to be changed after a ``Gemm`` has been constructed. + + The constructor has optional parameters for flexibly setting these parameters. The following + constructors are equivalent: + + .. highlight:: python + .. code-block:: python + + # Use F32 for A, B, C, D, and accumulation. All operands are row major. + + # Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts + # for operands to the same values. + Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor) + + # Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``. + Gemm(element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, element_C=cutlass_cppgen.DataType.f32, + element_D=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor) + + # Set the data types and elements from existing tensors. 
Note that one can use different tensors when + # executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must + # have the same data type and layout as those passed in here). + # A, B, C, and D are row-major torch.Tensor objects of type torch.float32 + Gemm(A=A, B=B, C=C, D=D) + + # Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (layout of D is + # the same as that for C, at present) + Gemm(element=cutlass_cppgen.DataType.f32, layout_A=cutlass_cppgen.LayoutType.RowMajor, + layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor) + + # Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types + # and layouts will inherit those passed in via the generic ``element`` and ``layout`` + Gemm(element_A=cutlass_cppgen.DataType.f32, layout_B=cutlass_cppgen.LayoutType.RowMajor, + element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor) + + The order of precedence for the setting of the data type and layout for a given operand/output is as follows: + 1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor + 2) Otherwise, if the data type/layout (e.g., ``element_A``, ``layout_A``) is specified, use those + 3) Otherwise, use the generic values (e.g., ``element``, ``layout``) + + :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90 + :type cc: int + :param kernel_cc: compute capability of kernels to generate. 
For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80 + :type kernel_cc: int + :param A: tensor representing data type and layout of operand A + :param B: tensor representing data type and layout of operand B + :param C: tensor representing data type and layout of operand C + :param D: tensor representing data type and layout of operand D + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param element_accumulator: data type to be used in accumulation of the product of operands A and B + :type element_accumulator: cutlass_cppgen.DataType + :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type + :type element: cutlass_cppgen.DataType + :param layout: generic layout type to be used for operands A, B, C, and D + :type layout: cutlass_cppgen.LayoutType + :param element_A: data type to be used for operand A + :type element_A: cutlass_cppgen.DataType + :param element_B: data type to be used for operand B + :type element_B: cutlass_cppgen.DataType + :param element_C: data type to be used for operand C + :type element_C: cutlass_cppgen.DataType + :param element_D: data type to be used for operand D + :type element_D: cutlass_cppgen.DataType + :param layout_A: layout of operand A + :type layout_A: cutlass_cppgen.LayoutType + :param layout_B: layout of operand B + :type layout_B: cutlass_cppgen.LayoutType + :param layout_C: layout of operand C + :type layout_C: cutlass_cppgen.LayoutType + :param layout_D: layout of operand D + :type layout_D: cutlass_cppgen.LayoutType + """ + + def __init__( + self, A=None, B=None, C=None, D=None, + alpha=1.0, beta=0.0, element_accumulator=None, + element=None, layout=None, + element_A=None, element_B=None, element_C=None, element_D=None, + layout_A=None, layout_B=None, layout_C=None, + cc: int = 
None, kernel_cc: int = None + ): + super().__init__(cc=cc, kernel_cc=kernel_cc) + self.name = "gemm" + self.compiled = False + + elements = [] + layouts = [] + + # Check that at least one of the following is set for each tensor (illustrated assuming tensor A): + # ``A``, ``element_A``, ``element`` and ``A``, ``layout_A``, ``layout`` + for elt, lay, tens, name in zip([element_A, element_B, element_C, element_D], + [layout_A, layout_B, layout_C, layout_C], + [A, B, C, D], + ["A", "B", "C", "D"]): + if elt is not None and tens is not None: + raise Exception(f'Must not specify both element_{name} and tensor {name}') + if lay is not None and tens is not None: + raise Exception(f'Must not specify both layout_{name} and tensor {name}') + if elt is None and tens is None and element is None: + raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.') + if lay is None and tens is None and layout is None: + raise Exception(f'Must specify one of layout_{name}, tensor {name}, or generic layout.') + + elt_to_set = None + lay_to_set = None + if tens is not None: + elt_to_set, lay_to_set = datatypes.get_datatype_and_layout(tens) + else: + elt_to_set = elt if elt is not None else element + lay_to_set = lay if lay is not None else layout + + elements.append(datatypes.library_type(elt_to_set)) + layouts.append(lay_to_set) + + self._element_a, self._element_b, self._element_c, self._element_d = elements + self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts + + if element_accumulator is None: + self._element_accumulator = self._element_c + else: + self._element_accumulator = datatypes.library_type(element_accumulator) + + self.A = A + self.B = B + self.C = C + self.D = D + + self.alpha = alpha + self.beta = beta + + self.epilogue_functor = None + self.op_class = None + self._tile_description = None + + self._reset_operations() + + self._swizzling_functor = cutlass_cppgen.swizzle.IdentitySwizzle1 + + def _reset_operations(self, 
reset_epilogue: bool = True): + # Set the default op class + datatype_comb = (self._element_a, self._element_b, self._element_accumulator) + layout_comb = (self._layout_a, self._layout_b) + + self.possible_op_classes = self.options.supporting_opclasses( + self._element_a, self._element_b, self._element_accumulator, + self._layout_a, self._layout_b, self._math_operation) + + if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes: + self.opclass = cutlass_cppgen.OpcodeClass.TensorOp + elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes: + self.opclass = cutlass_cppgen.OpcodeClass.Simt + else: + if self._math_operation is not None: + math_op_str = f' and math operation {self._math_operation}' + else: + math_op_str = '' + + raise Exception(f'No kernel configuration found for supported data type and layout ' + f'combination {datatype_comb}x{layout_comb}{math_op_str}') + + if reset_epilogue: + self._reset_epilogue_functor_activation(cutlass_cppgen.epilogue.identity) + + @property + def swizzling_functor(self): + """ + Returns the type of the swizzling functor currently being used by the GEMM + + :return: swizzing functor type + """ + return self._swizzling_functor + + @swizzling_functor.setter + def swizzling_functor(self, swizzling_functor): + """ + Sets the swizzling functor to the type specified by `swizzling_functor` + """ + if swizzling_functor == cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK: + if self.op_class == cutlass_cppgen.OpcodeClass.Simt: + raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp') + + if self.current_cc in [90, 100, 101, 103]: + raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90+') + self._swizzling_functor = swizzling_functor + + # + # Tile description Related + # + + @property + def tile_description(self) -> TileDescription: + """ + Returns the tile description + """ + return self._tile_description + + @tile_description.setter + def 
tile_description( + self, td=None): + """ + Set the tile description + + :param td: tile description + :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys + { + "threadblock_shape": [int, int, int], + "warp_count": [int, int, int], + "stages": int, + "instruction_shape": [int, int, int] (optional), + "cluster_shape": [int, int, int] (optional) + } + """ + if td is None: + return + if isinstance(td, dict): + if self._tile_description is None: + op = self.possible_operations.default_operation(self._math_operation) + self._tile_description = datatypes.td_from_profiler_op(op) + td = self._tile_description.clone_and_update(td) + + valid, msg = self._valid_tile_description(td) + if valid: + self._tile_description = td + else: + raise Exception(msg) + + def _valid_tile_description(self, td: TileDescription) -> tuple: + """ + Checks whether the provided tile description is valid for the given compute capability. At present, + this checks the following: + + - Does the tile description use a number of stages supported by the compute capability in question? + - Does the tile size requested fit within shared memory? + - Are cluster dimensions outside the valid range requested for a given architecture (e.g., + more non-unit cluster dimensions for pre-SM90 architectures)? + - Is the kernel schedule being used supported on the architecture in question? + + :param td: tile description to validate + :type td: cutlass_cppgen.backend.TileDescription + :return: tuple in which the first element is a bool indicating that the tile description is valid + and the second element is a string providing an optional error message. 
+ :rtype: tuple + """ + valid, msg = check.valid_stage_count(self.cc, self.current_cc, td, self._element_c, self._element_d) + if not valid: + return (valid, msg) + + valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape) + if not valid: + return (valid, msg) + + valid, msg = check.valid_schedule(self.current_cc, td.kernel_schedule, td.epilogue_schedule, td.tile_scheduler) + + if self.cc in [100, 101, 103] and td.kernel_schedule is not None and td.is_2sm and td.cluster_shape[0] % 2 != 0: + valid = False + msg = "Cluster shape must be divisible by 2 for 2SM kernels on SM100, SM101, and SM103" + + return valid, msg + + def tile_descriptions(self) -> list: + """ + Returns a list of valid tile descriptions for the operations + + :returns: list of valid tile descriptions for the operations + :rtype: list + """ + tds = [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations] + if self._math_operation is not None: + tds = [td for td in tds if td.math_instruction.math_operation == self._math_operation] + return tds + + def construct( + self, tile_description: TileDescription = None, + alignment_A: int = None, alignment_B: int = None, alignment_C: int = None) -> GemmOperationUniversal: + """ + Constructs a ``cutlass_cppgen.backend.GemmUniversalOperation`` based on the input parameters and current + kernel specification of the ``Gemm`` object. 
+ + :param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass_cppgen.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + + :return: operation that was constructed + :rtype: cutlass_cppgen.backend.GemmOperationUniversal + """ + alignment_pref_A = min(128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A"))) + alignment_pref_B = min(128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B"))) + alignment_A = check.alignment_or_default(alignment_A, alignment_pref_A) + alignment_B = check.alignment_or_default(alignment_B, alignment_pref_B) + + tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A) + tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B) + + if alignment_C is None: + alignment_C = max(self.possible_operations.alignments("C")) + if self._element_c != DataType.void: + alignment_C = min(128 // DataTypeSize[self._element_c], alignment_C) + + if tile_description is None: + if self._tile_description is None: + op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0] + tile_description = datatypes.td_from_profiler_op(op) + + # The selected op may have lower alignment than that determined above, so we must + # reset alignment here. + alignment_C = op.C.alignment + else: + tile_description = self._tile_description + else: + valid, err_str = self._valid_tile_description(tile_description) + if not valid: + raise Exception(f"Invalid tile description. 
{err_str}") + self._tile_description = tile_description + + tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C) + self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor) + + operation = GemmOperationUniversal( + arch=self.current_cc, + tile_description=tile_description, + A=tensor_A, B=tensor_B, C=tensor_C, + epilogue_functor=self.epilogue_functor, + swizzling_functor=self._swizzling_functor, + ) + + return operation + + def compile(self, tile_description: TileDescription = None, + alignment_A: int = None, alignment_B: int = None, alignment_C: int = None, + print_module: bool = False) -> cutlass_cppgen.backend.GemmOperationUniversal: + """ + Emits and compiles the kernel currently specified. If ``tile_description`` and any + of the ``alignment`` parameters are set, the kernel will be chosen using this + tile description and alignments. Otherwise, a default tile description and alignment + will be used. + + :param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass_cppgen.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + :param print_module: whether to print the emitted C++ code + :type print_module: bool + + :return: operation that was compiled + :rtype: cutlass_cppgen.backend.GemmOperationUniversal + """ + self.operation = self.construct(tile_description, alignment_A, alignment_B, alignment_C) + + if print_module: + print(self.operation.rt_module.emit()) + + compiler.add_module([self.operation,]) + return self.operation + + def _verify_rank(self, tensor): + """ + Verifies that ``tensor`` has rank greater than 1 + + :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type tensor: 
numpy/cupy/torch array/tensor object + """ + if len(tensor.shape) < 2: + raise Exception(f"Tensors must be of rank greater than 1. Received tensor of shape: {tensor.shape}") + + def _get_batch_count(self, A, B, C, D) -> int: + """ + Returns the batch count specified by the tensors A, B, C, and D and verifies that these + tensors match in batch size. Presence of a batch dimension is detected by one of the + tensors being rank 3. If a batch dimension is present, it must be present in one of + operands A, B, or C (but need not be in all), and must be present in D. + + :param A: tensor A + :type A: numpy/cupy/torch array/tensor object + :param B: tensor B + :type B: numpy/cupy/torch array/tensor object + :param C: tensor C + :type C: numpy/cupy/torch array/tensor object + :param D: tensor D + :type D: numpy/cupy/torch array/tensor object + + :return: tuple of batch count dimensions + :rtype: tuple + """ + A_batch = prod(A.shape[:-2]) if len(A.shape) > 2 else 1 + B_batch = prod(B.shape[:-2]) if len(B.shape) > 2 else 1 + + if 1 not in [A_batch, B_batch]: + if A_batch != B_batch: + raise Exception(f"Get invalid batch counts: A={A_batch}, B={B_batch}") + return max(A_batch, B_batch) + + def _get_batch_stride(self, tensor) -> int: + """ + Returns the batch stride of ``tensor``. If ``tensor`` is only rank-2, batch stride is 0. + + :param tensor: tensor object to process + :type tensor: numpy/cupy/torch array/tensor object + + :return: stride between each matrix in the batch + :rtype: int + """ + if tensor is not None and len(tensor.shape) > 2: + return tensor.shape[-2] * tensor.shape[-1] + else: + return 0 + + def _get_problem_args(self, A, B, C, D) -> tuple: + """ + Returns the problem size and GEMM universal mode to use for the + given operands. 
+ + :param A: tensor A + :type A: numpy/cupy/torch array/tensor object + :param B: tensor B + :type B: numpy/cupy/torch array/tensor object + :param C: tensor C + :type C: numpy/cupy/torch array/tensor object + :param D: tensor D + :type D: numpy/cupy/torch array/tensor object + + :return: tuple containing the problem size (cutlass_cppgen.shape.GemmCoord), the GEMM mode (cutlass_cppgen.GemmUniversalMode), and the batch count (int) + :rtype: tuple + """ + M, K = A.shape[-2:] + N = B.shape[-1] + mode = GemmUniversalMode.Gemm + + batch_count = self._get_batch_count(A, B, C, D) + returned_batch_count = batch_count + + # If we are running a batched GEMM in which there is a nonzero batch stride + # only for A, then we can fold the batched dimension of A into the M dimension + # (i.e., (b, m, k) x (k, n) -> (m*b, k) x (k, n)). This works only if both A + # and C are row major. A similar operation can be performed if only B has a nonzero + # batch dimension + if batch_count > 1: + A_row = self._layout_a == cutlass_cppgen.LayoutType.RowMajor + B_row = self._layout_b == cutlass_cppgen.LayoutType.RowMajor + C_row = self._layout_c == cutlass_cppgen.LayoutType.RowMajor + + # Consider a Tensor to be batched if its rank is > 2 and + # the product of the modes beyond rank 2 equals our pre-determined batch size. + batched = lambda x : x is None or (len(x.shape) > 2 and prod(x.shape[:-2]) == batch_count) + + if batched(A) and not batched(B) and (C is None or batched(C)) and A_row and C_row: + M *= batch_count + returned_batch_count = 1 + elif not batched(A) and batched(B) and (C is None or batched(C)) and not B_row and not C_row: + N *= batch_count + returned_batch_count = 1 + else: + mode = GemmUniversalMode.Batched + + return GemmCoord(M, N, K), mode, returned_batch_count + + def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name): + """ + Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception + is raised if it does not. 
+ + :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type tensor: numpy/cupy/torch array/tensor object + :param ref_dtype: data type for the tensor that this object was initialized to + :param ref_layout: layout for the tensor that this object was initialized to + :param name: identifier of the tensor to verify. Used in raising exceptions + :type name: str + """ + dtype, layout = datatypes.get_datatype_and_layout(tensor) + if dtype != ref_type or layout != ref_layout: + try: + # Attempt to transpose the tensor to fit the desired layout + tensor = tensor.transpose(-1, -2) + except: + raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) ' + f'does not match the expected type and ' + f'layout of ({ref_type}, {ref_layout}) and transpose failed.') + + def run(self, A=None, B=None, C=None, D=None, + alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None, + stream: Optional[cuda.CUstream] = None) -> GemmArguments: + """ + Runs the kernel currently specified. If it has not already been, the kernel is emitted and + compiled. Tensors holding operands and outputs of the kernel are sourced either from the + ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta`` + parameters provided in this call, or from those + passed in on the construction of this object -- one of the two must be specified. + + By default, this call returns only once the kernel has completed. To launch the kernel + and immediately return, set ``sync=False``. In this case, it is the responsibility of the + caller to syncrhonize the results of the kernel before attempting to access outputs + by calling ``sync()`` on the arguments returned from this call. 
+ + :param A: tensor representing data type and layout of operand A + :param B: tensor representing data type and layout of operand B + :param C: tensor representing data type and layout of operand C + :param D: tensor representing data type and layout of operand D + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param sync: whether the call should wait for the kernel to complete before returning + :type sync: bool + :param print_module: whether to print the emitted C++ code + :type print_module: bool + :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) + :type stream: :class:`cuda.cuda.CUstream` + + :return: arguments passed in to the kernel + :rtype: cutlass_cppgen.backend.GemmArguments + """ + if not stream: + stream = cuda.CUstream(0) + super().run_setup() + A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A") + B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B") + C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C") + D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D") + alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha") + beta = self._verify_scalar(beta, self.beta, self._element_c, "beta") + + is_void_c = self._element_c == DataType.void + + self._verify_rank(A) + self._verify_rank(B) + if not is_void_c: + self._verify_rank(C) + self._verify_rank(D) + + alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A") + alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B") + + # Set C alignment based on D.shape so as to correctly get an alignment with void-C + # kernels, for which `C` is None. 
+ alignment_c = self.possible_operations.find_alignment(D.shape, self._layout_c, operand="C") + self.compile(self._tile_description, alignment_A=alignment_a, alignment_B=alignment_b, + alignment_C=alignment_c, print_module=print_module) + + problem_size, mode, batch_count = self._get_problem_args(A, B, C, D) + + if mode == GemmUniversalMode.Gemm or batch_count == 1: + kwargs = {'split_k_slices': 1} + else: + kwargs = { + 'batch': batch_count, + 'batch_strides': { + 'A': self._get_batch_stride(A), + 'B': self._get_batch_stride(B), + 'C': self._get_batch_stride(C), + 'D': self._get_batch_stride(D) + } + } + + kwargs['stream'] = stream + + if isinstance(self.epilogue_functor, EpilogueFunctorVisitor): + output_op = self.operation.epilogue_type(visitor_args) + else: + output_op = self.operation.epilogue_type(alpha, beta) + + arguments = GemmArguments( + operation=self.operation, problem_size=problem_size, + A=A, B=B, C=C, D=D, + output_op=output_op, + gemm_mode=mode, + **kwargs + ) + + self.operation.run(arguments) + + if sync: + arguments.sync() + + return arguments diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py new file mode 100644 index 0000000000000000000000000000000000000000..59f90535c29a816541bc1a2155fea35afd1c94fd --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py @@ -0,0 +1,269 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" + Ease-of-use interface for constructing, compiling, and running GEMMs. + + The ``GroupedGemm`` interface is meant to allow one to easily instantiate, compile, and run + grouped GEMM operations in CUTLASS via Python, without specifying many configuration parameters. + Under the hood, the interface will select sensible default parameters for the many template + parameters for CUTLASS grouped GEMMs. + + Note: optimal performance is not to be expected from this interface. 
To achieve optimal + performance, one should specify and tune each configuration parameter. + + The simplest example of using this interface is the following: + + .. highlight:: python + .. code-block:: python + + # As, Bs, Cs, and Ds are torch/numpy/cupy tensor objects + plan = cutlass_cppgen.op.GroupedGemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) + plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1]) +""" +from __future__ import annotations +from typing import Optional +from cutlass_library import DataTypeSize + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +from cutlass_cppgen.backend.gemm_operation import ( + GemmGroupedArguments, + GemmOperationGrouped, +) +from cutlass_cppgen.backend.library import ( + SchedulerMode, + TensorDescription, + TileDescription, +) +from cutlass_cppgen.op.gemm import Gemm +from cutlass_cppgen.shape import GemmCoord +from cutlass_cppgen.utils import check, datatypes + + +class GroupedGemm(Gemm): + """ + Constructs a ``GroupedGemm`` object. + + The data types and layouts of operands A, B, and C, along with the data type of output D + and that used for accumulation, are bound to the ``GroupedGemm`` object throughout its lifetime -- + these are not to be changed after a ``GroupedGemm`` has been constructed. + + The constructor has optional parameters for flexibly setting these parameters. Please see the constructor + for ``Gemm`` for examples of these. 
+ + :param cc: compute capability of device to generate kernels for + :type cc: int + :param A: tensor representing data type and layout of operands A + :param B: tensor representing data type and layout of operands B + :param C: tensor representing data type and layout of operands C + :param D: tensor representing data type and layout of operands D + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param element_accumulator: data type to be used in accumulation of the product of operands A and B + :type element_accumulator: cutlass_cppgen.DataType + :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type + :type element: cutlass_cppgen.DataType + :param layout: generic layout type to be used for operands A, B, C, and D + :type layout: cutlass_cppgen.LayoutType + :param element_A: data type to be used for operand A + :type element_A: cutlass_cppgen.DataType + :param element_B: data type to be used for operand B + :type element_B: cutlass_cppgen.DataType + :param element_C: data type to be used for operand C + :type element_C: cutlass_cppgen.DataType + :param element_D: data type to be used for operand D + :type element_D: cutlass_cppgen.DataType + :type layout_A: layout of operand A + :param layout_A: cutlass_cppgen.LayoutType + :type layout_B: layout of operand B + :param layout_B: cutlass_cppgen.LayoutType + :type layout_C: layout of operand C + :param layout_C: cutlass_cppgen.LayoutType + :type layout_D: layout of operand D + :param layout_D: cutlass_cppgen.LayoutType + """ + + def __init__( + self, A=None, B=None, C=None, D=None, + alpha=1.0, beta=0.0, element_accumulator=None, + element=None, layout=None, + element_A=None, element_B=None, element_C=None, element_D=None, + layout_A=None, layout_B=None, layout_C=None, + cc: int = None, + ): + super().__init__( + A=A, B=B, C=C, 
D=D, + alpha=alpha, beta=beta, + element_accumulator=element_accumulator, + element=element, layout=layout, + element_A=element_A, element_B=element_B, + element_C=element_C, element_D=element_D, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, + cc=cc + ) + + # Grouped GEMM specializations for SM90 are currently unavailable. Revert to using SM80 + if self.current_cc in [90, 100, 101, 103]: + self._reset_options(80) + self._reset_operations(reset_epilogue=False) + + self.name = "grouped_gemm" + + @Gemm.swizzling_functor.setter + def swizzling_functor(self, swizzling_functor): + """ + Sets the swizzling functor to the type specified by `swizzling_functor` + """ + raise Exception('Grouped GEMM does not currently support different swizzling functors') + + def construct(self, tile_description: TileDescription = None, + alignment_A: int = None, + alignment_B: int = None, + alignment_C: int = None) -> GemmOperationGrouped: + """ + Constructs a ``cutlass_cppgen.backend.GemmOperationGrouped`` based on the input parameters and current + kernel specification of the ``Gemm`` object. 
+ + :param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass_cppgen.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + + :return: operation that was constructed + :rtype: cutlass_cppgen.backend.GemmOperationGrouped + """ + alignment_A = check.alignment_or_default(alignment_A, max(self.possible_operations.alignments("A"))) + alignment_B = check.alignment_or_default(alignment_B, max(self.possible_operations.alignments("B"))) + alignment_C = check.alignment_or_default(alignment_C, max(self.possible_operations.alignments("C"))) + + self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor) + + tensor_A = TensorDescription(self._element_a, self._layout_b, alignment_A) + tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B) + tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C) + + if tile_description is None: + op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0] + tile_description = datatypes.td_from_profiler_op(op) + else: + valid, err_str = self._valid_tile_description(tile_description) + if not valid: + raise Exception(f"Invalid tile description. 
{err_str}") + self.tile_description = tile_description + + operation = GemmOperationGrouped( + arch=self.current_cc, + tile_description=tile_description, + A=tensor_A, B=tensor_B, C=tensor_C, + epilogue_functor=self.epilogue_functor, + swizzling_functor=self._swizzling_functor, + precompute_mode=SchedulerMode.Device) + + return operation + + def run(self, A, B, C, D, + alpha=None, beta=None, sync: bool = True, + print_module: bool = False, + stream: Optional[cuda.CUstream] = None) -> GemmGroupedArguments: + """ + Runs the kernel currently specified. + + By default, this call returns only once the kernel has completed. To launch the kernel + and immediately return, set ``sync=False``. In this case, it is the responsibility of the + caller to syncrhonize the results of the kernel before attempting to access outputs + by calling ``sync()`` on the arguments returned from this call. + + :param A: list of tensors representing data type and layout of operand A + :type A: list + :param B: list of tensors representing data type and layout of operand B + :type B: list + :param C: list of tensors representing data type and layout of operand C + :type C: list + :param D: list of tensors representing data type and layout of operand D + :type D: list + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param sync: whether the call should wait for the kernel to complete before returning + :type sync: bool + :param print_module: whether to print the emitted C++ code + :type print_module: bool + :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) + :type stream: :class:`cuda.cuda.CUstream` + + :return: arguments passed in to the kernel + :rtype: cutlass_cppgen.backend.GemmGroupedArguments + """ + if not stream: + stream = cuda.CUstream(0) + + super().run_setup() + + if len(A) != len(B) or len(A) != len(C) or len(A) != len(D): + raise 
Exception("Lengths of A, B, C, and D lists must be equal") + + problem_sizes = [] + As, Bs, Cs, Ds = ([None] * len(A) for _ in range(4)) + for i in range(len(A)): + As[i] = self._verify_tensor(A[i], self.A, self._element_a, self._layout_a, "A") + Bs[i] = self._verify_tensor(B[i], self.B, self._element_b, self._layout_b, "B") + Cs[i] = self._verify_tensor(C[i], self.C, self._element_c, self._layout_c, "C") + Ds[i] = self._verify_tensor(D[i], self.D, self._element_d, self._layout_d, "D") + problem_sizes.append(GemmCoord(A[i].shape[0], B[i].shape[1], A[i].shape[1])) + + alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha") + beta = self._verify_scalar(beta, self.beta, self._element_c, "beta") + + alignment_a = min((self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A") for A in As)) + alignment_b = min((self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B") for B in Bs)) + alignment_c = min((self.possible_operations.find_alignment(C.shape, self._layout_c, operand="C") for C in Cs)) + self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b, + alignment_C=alignment_c, print_module=print_module) + + arguments = GemmGroupedArguments( + operation=self.operation, + problem_sizes=problem_sizes, + A=As, B=Bs, C=Cs, D=Ds, + output_op=self.operation.epilogue_type(alpha, beta), + stream=stream + ) + + self.operation.run(arguments) + + if sync: + arguments.sync() + + return arguments diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py new file mode 100644 index 0000000000000000000000000000000000000000..bebf07a7e5b83a1cf14cfecf19e90f730e305dce --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py @@ -0,0 +1,431 @@ 
+################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d) +""" + +from bisect import bisect_left + +from cutlass_library import ( + DataType, + DataTypeSize, + MathOperation, + OperationKind, + SharedMemPerCC +) + +import cutlass_cppgen +from cutlass_cppgen import get_option_registry +from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor +from cutlass_cppgen.backend.evt.passes.util import cc_map +from cutlass_cppgen.backend.utils.device import device_cc +from cutlass_cppgen.epilogue import get_activations, get_activation_epilogue, identity +from cutlass_cppgen.library_defaults import KernelsForDataType, _generator_ccs +from cutlass_cppgen.swizzle import get_swizzling_functors +from cutlass_cppgen.utils import datatypes, check + + +class OperationBase: + """ + Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d) + """ + + def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = OperationKind.Gemm): + """ + :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90 + :type cc: int + :param kernel_cc: compute capability of kernels to generate. 
For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80 + :type kernel_cc: int + :param operation_kind: class of operation that will be performed (e.g., GEMM, Conv) + :type operation_kind: cutlass_library.OperationKind + """ + self.operation_kind = operation_kind + self.cc = cc if cc is not None else device_cc() + self.specified_kernel_cc = kernel_cc is not None + self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc) + self.tile_description = None + self._math_operation = None + + self.options = get_option_registry().options_for_cc(self.current_cc, operation_kind) + + if self.options is None: + raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}") + + # Default activation function: identity + self._activation = identity + + def _find_closest_cc(self, cc: int) -> int: + """ + Returns the closest CC in _generator_ccs less than or equal to `cc` + + :param cc: compute capability to query + :type cc: int + + :returns: closest CC in _generator_ccs less than or equal to `cc` + :rtype: int + """ + if cc in _generator_ccs: + return cc + + # Find closest CC lower than this CC + idx = bisect_left(_generator_ccs, cc) + if idx == 0: + raise Exception(f'No valid CC to fall back to for {cc}') + return _generator_ccs[idx-1] + + def activations(self) -> list: + """ + Returns possible activation functions that can be used + + :return: list of activation functions that can be used + :rtype: list + """ + return get_activations() + + def swizzling_functors(self) -> list: + """ + Returns possible swizzling functions that can be used + + :return: list of swizzling functions that can be used + :rtype: list + """ + return get_swizzling_functors() + + def _reset_options(self, cc: int): + """ + Resets the kernel options based on cc + + :param cc: compute capability to reset to + :type cc: int + """ + if cc != self.current_cc: + if cc not in _generator_ccs: + raise 
Exception(f'Invalid CC for CUTLASS kernels: {cc}.') + self.current_cc = cc + self.options = get_option_registry().options_for_cc(self.current_cc, self.operation_kind) + + def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name): + """ + Verifies the following properties: + 1) Either ``scalar`` or ``ref_scakar`` must be set (i.e., not ``None``) + 2) If ``scalar`` is not ``None``, its datatype must match matches the current version + set by the plan (i.e., those in ``ref_dtype``) + + If either of these properties does not hold, an exception is raised. If these properties hold and + ``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned. + + :param scalar: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type scalar: numpy/cupy/torch scalar + :param ref_scalar: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in + :type ref_scalar: numpy/cupy/torch scalar + :param ref_dtype: data type for the scalar that this object was initialized to + :param name: identifier of the scalar to verify. Used in raising exceptions + :type name: str + + :return: valid scalar to use + :rtype: numpy/cupy/torch scalar + """ + if scalar is None: + if ref_scalar is None: + raise Exception(f"Scalar {name} must be set.") + return ref_scalar + if hasattr(scalar, "dtype"): + dtype = datatypes.library_type(scalar.dtype) + if dtype != ref_dtype: + raise Exception( + f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}." 
+ ) + return scalar + + def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name): + """ + Verifies the following properties: + If ref_dtype is not void: + 1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``) + 2) If ``tensor`` is not ``None``, its datatype and layout must match matches the current versions + set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``) + If ref_dtype is void: + Neither ``tensor`` nor ``ref_tensor`` are set + + If either of these properties does not hold, an exception is raised. If these properties hold and + ``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned. + + :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type tensor: numpy/cupy/torch array/tensor object + :param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in + :type ref_tensor: numpy/cupy/torch array/tensor object + :param ref_dtype: data type for the tensor that this object was initialized to + :param ref_layout: layout for the tensor that this object was initialized to + :param name: identifier of the tensor to verify. 
Used in raising exceptions + :type name: str + + :return: valid tensor object to use + :rtype: numpy/cupy/torch array/tensor object + """ + if ref_dtype == DataType.void: + if tensor is not None or ref_tensor is not None: + raise Exception("Operands with element DataType.void must not be provided a tensor") + return None + + if tensor is None: + if ref_tensor is None: + raise Exception(f"Tensor {name} must be set.") + return ref_tensor + + self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name) + return tensor + + @property + def opclass(self) -> cutlass_cppgen.OpcodeClass: + """ + Returns the opcode class currently in use + + :return: opcode class currently in use + :rtype: cutlass_cppgen.OpcodeClass + """ + return self.op_class + + @opclass.setter + def opclass(self, oc: cutlass_cppgen.OpcodeClass): + if isinstance(oc, str): + oc = datatypes.getattr_enum(cutlass_cppgen.OpcodeClass, oc) + if oc in self.possible_op_classes: + self.op_class = oc + else: + raise Exception( + f'Unsupported operation class {oc} for CC {self.cc} and data type combination ' + f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and ' + f'layout combination ({self._layout_a}, {self._layout_b}).') + + # Changing the op class also changes the possible operations available. Reset these. + self.possible_operations = self.options.operations( + self.op_class, self._element_a, self._element_b, + self._element_accumulator, self._layout_a, self._layout_b, self._math_operation) + + # Changing the op class changes the elements per access in the epilogue. Reset this. 
+ if self.epilogue_functor is not None: + self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor) + + @property + def math_operation(self) -> cutlass_cppgen.MathOperation: + """ + Returns the math operation currently in use + + :return: math operation currently in use + :rtype: cutlass_cppgen.MathOperation + """ + return self._math_operation + + @math_operation.setter + def math_operation(self, mo: cutlass_cppgen.MathOperation): + if isinstance(mo, str): + mo = datatypes.getattr_enum(cutlass_cppgen.MathOperation, mo) + + if not self.specified_kernel_cc: + if self.current_cc in [90, 100, 101, 103]: + # CUTLASS 3.0 kernels do not use different math operations. If one is specified, we + # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. + cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") + self._reset_options(80) + self._reset_operations(reset_epilogue=False) + elif self.current_cc in [90, 100, 101, 103]: + raise Exception("CUTLASS 3.0 kernels do not use different math operations. 
" + "To use 2.x kernels with a specific math operation, do not set the `kernel_cc`" + "parameter when constructing the plan.") + + self._math_operation = mo + self._reset_operations() + + def _elements_per_access(self): + if self.op_class == cutlass_cppgen.OpcodeClass.Simt: + return 1 + elif self._element_c != DataType.void: + return 128 // DataTypeSize[self._element_c] + else: + return 128 // max(self.possible_operations.alignments("C")) + + def _create_epilogue_functor_activation(self, activation): + """ + Returns the epilogue functor with given activation function + """ + if self.epilogue_functor is None: + elements_per_access = self._elements_per_access() + else: + elements_per_access = self.epilogue_functor.epilogue_vector_length + + if not self.specified_kernel_cc: + if self.current_cc in [90, 100, 101, 103] and activation != identity: + # CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests a non-identity activation, + # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. + cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") + if self._element_c != self._element_d: + raise Exception("CUTLASS 2.x kernels require element C to be the same as element D") + self._reset_options(80) + self._reset_operations(reset_epilogue=False) + elif (self.cc in [90, 100, 101, 103] and self.current_cc not in [90, 100, 101, 103] and activation == identity and self._math_operation is None): + # SM80 fallback kernels are currently used. Since an identity activation is requested, + # we can switch back to using SM90 kernels. + self._reset_options(self.cc) + self._reset_operations(reset_epilogue=False) + else: + if self.current_cc in [90, 100, 101, 103] and activation != identity: + raise Exception("Epilogues with elementwise fusion are not currently supported " + "in the Python interface for 3.x kernels. 
To use 2.x kernels " + "with fused elementwise epilogues, do not set the `kernel_cc` " + "parameter when constructing the plan.") + + return get_activation_epilogue( + activation, + self._element_d, + elements_per_access, + self._element_accumulator, + self._element_accumulator, + ) + + def _reset_epilogue_functor_activation(self, activation): + """ + Set the epilogue functor based on the provided activation function + """ + self.epilogue_functor = self._create_epilogue_functor_activation(activation) + + def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor): + """ + Reset the alignment of the current epilogue functor based on alignment C + """ + if isinstance(epilogue_functor, EpilogueFunctorVisitor): + return epilogue_functor + + if epilogue_functor is None or not hasattr(epilogue_functor, 'activation_functor'): + # Identity epilogue does not have 'activation_functor' + activation = identity + else: + activation = epilogue_functor.activation_functor + + epilogue_functor = get_activation_epilogue( + activation, + self._element_d, + alignment, + self._element_accumulator, + self._element_accumulator, + ) + return epilogue_functor + + @property + def activation(self): + """ + Returns the type of the current activation function used + """ + if hasattr(self.epilogue_functor, "activation_functor"): + return self.epilogue_functor.activation_functor + else: + return identity + + @activation.setter + def activation(self, act): + """ + Sets the type of the activation function to use + Activation can come with a set of arguments + + :param act: type of activation function to use + :type act: str or tuple. e.g. 
"relu", ("leaky_relu", 0.01) + + """ + if isinstance(act, tuple): + if isinstance(act[0], str): + act_fn = getattr(cutlass_cppgen.backend.epilogue, act[0]) + else: + act_fn = act[0] + self._reset_epilogue_functor_activation(act_fn) + self._activation_args = act[1] + self._activation = act[0] + else: + if isinstance(act, str): + act = getattr(cutlass_cppgen.backend.epilogue, act) + self._reset_epilogue_functor_activation(act) + self._activation = act + + @property + def epilogue_visitor(self): + """ + Return the epilogue functor + """ + return self.epilogue_functor + + @epilogue_visitor.setter + def epilogue_visitor(self, visitor): + """ + Create the epilogue visitor + """ + self.epilogue_functor = EpilogueFunctorVisitor(cc_map[self.cc], visitor) + + # The epilogue_functor may consume too much shared memory + # Reset the possible operations + if self.cc not in [90, 100, 101, 103]: + # The shared memory is only a concern for sm90+ epilogue + # In sm80, the epilogue and mainloop share the shared memory + return + + datatype_comb = self.possible_operations.datatype_comb + layout_comb = self.possible_operations.layout_comb + new_possible_operations = KernelsForDataType(datatype_comb, layout_comb) + for operation in self.possible_operations.all_operations: + td = datatypes.td_from_profiler_op(operation) + # Filter invalid epilogue schedules + if cc_map[self.cc] == 90 and td.epilogue_schedule not in [ + cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized, + cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative]: + continue + epilogue_smem_bytes = self.epilogue_functor.get_smem_size(td) + + # Verify the maximum number of mainloop stages + mainloop_smem_per_stage = check.calculate_smem_usage_per_stage(td, OperationKind.Gemm) + smem_capacity_bytes = SharedMemPerCC[self.cc] << 10 + mainloop_stages = (smem_capacity_bytes - epilogue_smem_bytes) // mainloop_smem_per_stage + if mainloop_stages < 2: + # Mainloop stages must >= 2 + continue + + 
new_possible_operations.add(operation) + if len(new_possible_operations.all_operations) == 0: + raise RuntimeError( + "The epilogue consumes too much shared memory. " + "No valid tile description is found in the generator.") + self.possible_operations = new_possible_operations + + + def run_setup(self): + """ + Steps that must be taken before caling `plan.run()` + """ + # Initialize the memory pool if, if not already done + cutlass_cppgen.get_memory_pool() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py new file mode 100644 index 0000000000000000000000000000000000000000..a718f9bb4432f1f51457661abe27e24ea818aba4 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py @@ -0,0 +1,184 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for expressing shapes +""" + +from cutlass_library import ( + ConvMode, + ConvKind, + LayoutType +) +from cutlass_cppgen.backend.c_types import ( + Conv2DProblemSize_, + GemmCoord_, + GemmCoordBatched_ +) + + +class MatrixCoord: + def __init__(self, row, col): + self._row = row + self._col = col + + @property + def row(self): + return self._row + + @property + def column(self): + return self._col + + def leading_dimension(self, layout: LayoutType) -> int: + """ + Returns the leading dimension for a matrix with layout ``layout`` and shape provided by the MatrixCoord. 
+ + :param layout: layout of matrix + :type layout: cutlass_library.LayoutType + + :returns: leading dimension + :rtype: int + """ + if layout == LayoutType.RowMajor: + return self._col + elif layout == LayoutType.ColumnMajor: + return self._row + else: + raise Exception(f'Unsupported layout for leading dimension calculation: {layout}') + + +class GemmCoord: + def __init__(self, m: int, n: int, k: int): + self._m = m + self._n = n + self._k = k + + @property + def m(self) -> int: + return self._m + + @property + def n(self) -> int: + return self._n + + @property + def k(self) -> int: + return self._k + + @property + def mk(self) -> MatrixCoord: + return MatrixCoord(self._m, self._k) + + @property + def mn(self) -> MatrixCoord: + return MatrixCoord(self._m, self._n) + + @property + def kn(self) -> MatrixCoord: + return MatrixCoord(self._k, self._n) + + @property + def ctype(self) -> GemmCoord_: + return GemmCoord_(self._m, self._n, self._k) + + def batched_ctype(self, batch_count: int) -> GemmCoordBatched_: + return GemmCoordBatched_(self._m, self._n, self._k, batch_count) + + +class Conv2DProblemSize: + def __init__( + self, n: int, h: int, w: int, c: int, + k: int, r: int, s: int, c_: int, + pad_h: int, pad_w: int, stride_h: int, stride_w: int, + dilation_h: int, dilation_w: int, mode: ConvMode=ConvMode.CrossCorrelation, + split_k_slices: int=1, groups: int=1): + + self.N = n + self.H = h + self.W = w + self.C = c + self.K = k + self.R = r + self.S = s + self.pad_h = pad_h + self.pad_w = pad_w + self.stride_h = stride_h + self.stride_w = stride_w + self.dilation_h = dilation_h + self.dilation_w = dilation_w + self.mode = int(mode) + self.split_k_slices = split_k_slices + self.groups = groups + self.P = ((h + pad_h * 2 - r * dilation_h) // stride_h) + 1 + self.Q = ((w + pad_w * 2 - s * dilation_w) // stride_w) + 1 + + @property + def ctype(self) -> Conv2DProblemSize_: + return Conv2DProblemSize_(self) + + def implicit_gemm_size(self, kind: ConvKind): + if kind == 
ConvKind.Fprop: + return GemmCoord( + self.N * self.P * self.Q, + self.K, + self.R * self.S * self.C // self.groups + ) + elif kind == ConvKind.Dgrad: + return GemmCoord( + self.N * self.H * self.W, + self.C, + self.R * self.S * self.K + ) + elif kind == ConvKind.Wgrad: + return GemmCoord( + self.K, + self.R * self.S * self.C, + self.N * self.P * self.Q + ) + + @staticmethod + def from_sizes(input_size, weight_size): + K, R, S, _ = weight_size + pad_h = R // 2 + pad_w = S // 2 + stride_h = 1 + stride_w = 1 + dilation_h = 1 + dilation_w = 1 + return Conv2DProblemSize( + *input_size, + *weight_size, + pad_h, pad_w, + stride_h, stride_w, + dilation_h, dilation_w + ) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py new file mode 100644 index 0000000000000000000000000000000000000000..ffd9483415ea36716bf4643d27b8d92f3e9878a5 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py @@ -0,0 +1,65 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Registry of swizzling functions +""" + +from cutlass_library import SwizzlingFunctor + + +IdentitySwizzle1 = SwizzlingFunctor.Identity1 +IdentitySwizzle2 = SwizzlingFunctor.Identity2 +IdentitySwizzle4 = SwizzlingFunctor.Identity4 +IdentitySwizzle8 = SwizzlingFunctor.Identity8 +HorizontalSwizzle = SwizzlingFunctor.Horizontal +ThreadblockSwizzleStreamK = SwizzlingFunctor.StreamK +StridedDgradIdentitySwizzle1 = SwizzlingFunctor.StridedDgradIdentity1 +StridedDgradIdentitySwizzle4 = SwizzlingFunctor.StridedDgradIdentity4 +StridedDgradHorizontalSwizzle = SwizzlingFunctor.StridedDgradHorizontal + + +_swizzling_functors = [ + IdentitySwizzle1, + IdentitySwizzle2, + IdentitySwizzle4, + IdentitySwizzle8, + HorizontalSwizzle, + ThreadblockSwizzleStreamK, + StridedDgradIdentitySwizzle1, + StridedDgradIdentitySwizzle4, + StridedDgradHorizontalSwizzle, +] + + +def get_swizzling_functors(): + return 
_swizzling_functors diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..75d8416a15070ddcf2c6270248ccd9deff8e2137 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py @@ -0,0 +1,41 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_cppgen.utils.check import ( + alignment_or_default, + calculate_smem_usage, + calculate_smem_usage_per_stage, + valid_cluster_shape, + valid_schedule, + valid_stage_count, + update_alignment, +) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py new file mode 100644 index 0000000000000000000000000000000000000000..108f268b4bc54ec0839afb5c1602ba63e5b98743 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py @@ -0,0 +1,262 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utility functions for checking constraints on kernels and calculating kernel attributes +""" + +import ctypes + +from cutlass_library import DataTypeSize, KernelScheduleSuffixes, OperationKind, SharedMemPerCC + +import cutlass_cppgen +from cutlass_cppgen.backend.library import TileDescription + + +def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: OperationKind) -> int: + """ + Returns the amount of shared memory in bytes consumed in a single stage of a kernel. 
+ + :param td: tile description to compute shared memory of + :type td: TileDescription + :param operation_kind: identifier for the type of operation being performed + :type operation_kind: cutlass_library.OperationKind + + :return: number of bytes of shared memory consumed by a single stage + :rtype: int + """ + m, n, k = td.blackwell_threadblock_shape + if td.is_2sm: + m //= 2 + + if operation_kind == OperationKind.Gemm: + stage_barrier_bytes = 32 + return ( + (DataTypeSize[td.math_instruction.element_a] * m * k // 8) + + (DataTypeSize[td.math_instruction.element_b] * k * n // 8) + + stage_barrier_bytes + ) + else: + raise Exception(f"No available shared memory calculation for operation kind {operation.operation_kind}") + + +def calculate_smem_usage(operation) -> int: + """ + Returns the amount of shared memory in bytes consumed by a kernel. + + :return: number of bytes of shared memory consumed by the operation + :return: int + """ + _per_stage = calculate_smem_usage_per_stage(operation.tile_description, operation.operation_kind) + return _per_stage * operation.tile_description.stages + + +def valid_stage_count( + cc: int, + kernel_cc: int, + td: TileDescription, + element_C: cutlass_cppgen.DataType = None, + element_D: cutlass_cppgen.DataType = None, + verbose: bool = True) -> tuple: + """ + Checks whether a device with `cc` supports the number of stages within `tile_description`, both + based on raw limits on the number of stages and based on shared memory capacity + + :param cc: compute capability of device in question + :type cc: int + :param kernel_cc: compute capability that the kernel targets (corresponding to the arch::SMxy tag in CUTLASS) + :type kernel_cc: int + :param td: tile description to check + :type td: TileDescription + :param element_C: data type of operand C + :type element_C: cutlass_cppgen.DataType + :param element_D: data type of operand D + :type element_D: cutlass_cppgen.DataType + :param verbose: whether to log warnings + :type verbose: 
bool + + :return: tuple with the first element indicating whether the provided tile description is + valid for the provided device and the second element being an error message + :rtype: tuple + """ + if kernel_cc in [90, 100, 101, 103]: + if (td.stages is None or td.stages == 0): + # Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically + # determines the stage count to use. Thus, all settings are valid in these scenarios. + return (True, "") + elif verbose: + cutlass_cppgen.logger.warning( + "Setting an explicit stage count for SM90 kernels currently may " + "result in compilation errors if the combination of tile shape, " + "stage count, and shared memory requirement of the epilogue exceeds " + "the available shared memory per SM.") + + if td.stages <= 0: + return (False, f"Stage counts must be positive integers. Tile description has stage count of {td.stages}.") + + if cc < 80 and td.stages != 2: + return (False, f"Tile description has stage count of {td.stages}, " + f"but only 2 stages are supported on SM{cc}.") + + # The calculation below does not consider shared memory used by the epilogue and, thus, + # only catches cases in which the mainloop exceeds the device's shared memory capacity. + # This is not a concern for CUTLASS 2.x kernels, for which the shared memory of the + # mainloop and epilogue is shared. + smem_per_stage = calculate_smem_usage_per_stage(td, OperationKind.Gemm) + smem_usage_mainloop = (smem_per_stage * td.stages) + smem_arch = SharedMemPerCC[cc] << 10 + if smem_usage_mainloop > smem_arch: + return ( False, + "Configuration uses too much shared memory. 
Consider reducing stage count or tile shape.\n" + f"Details:\n" + f"Mainloop uses {smem_per_stage} bytes of shared memory per stage, and " + f"{td.stages} stages for a total of {smem_usage_mainloop} bytes.\n" + f"The maxmium amount of shared memory that can be used per block on CC {cc} is {smem_arch}.") + + return (True, "") + + +def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple: + """ + Checks whether a device with `cc` supports a thread block cluster of shape `cluster_shape`. + + :param cc: compute capability of device in question + :type cc: int + :param cluster_shape: dimensions of thread block cluster shape to check + :type cluster_shape: list + + :return: tuple with the first element indicating whether the provided cluster shape is + valid for the provided device and the second element being an error message + :rtype: tuple + """ + + if cc < 90 or cc in [120, 121]: + if cluster_shape != [1, 1, 1]: + return (False, + f"Cluster shape for pre-SM90 architectures and SM 120 and 121 must be [1, 1, 1]. Received cluster shape of " + f"{cluster_shape} for SM{cc}.") + else: + return (True, "") + + if len(cluster_shape) != 3: + return (False, + f"Cluster shapes must be rank-3. Received {cluster_shape} (rank {len(cluster_shape)}") + + if cluster_shape[2] != 1: + return (False, + "CUTLASS kernels currently require the third dimension of cluster shape to be 1. " + f"Received cluster shape of {cluster_shape}.") + + return (True, "") + + +def valid_schedule( + cc: int, + kernel_schedule: cutlass_cppgen.KernelScheduleType, + epilogue_schedule: cutlass_cppgen.EpilogueScheduleType, + tile_scheduler: cutlass_cppgen.TileSchedulerType) -> tuple: + """ + Checks that the kernel and epilogue schedules passed in are a valid combination for + a device of compute capability ``cc``. 
+ + :param cc: compute capability of device in question + :type cc: int + :param kernel_schedule: kernel schedule type + :type kernel_schedule: cutlass_cppgen.KernelScheduleType + :param epilogue_schedule: epilogue schedule type + :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType + :param tile_scheduler: tile scheduler type + :type tile_scheduler: cutlass_cppgen.TileSchedulerType + + :return: tuple with the first element indicating whether the provided schedules are + valid for the provided device and the second element being an error message + :rtype: tuple + """ + kernel_auto = (kernel_schedule == cutlass_cppgen.KernelScheduleType.ScheduleAuto) + epilogue_auto = (epilogue_schedule == cutlass_cppgen.EpilogueScheduleType.ScheduleAuto) + tile_scheduler_default = (tile_scheduler == cutlass_cppgen.TileSchedulerType.Default) + if (cc < 90 or cc in [120, 121]) and not (kernel_auto and epilogue_auto and tile_scheduler_default): + return (False, "Non-default schedules are only supported on SM90 and beyond (excluding SM120 and SM121)") + + if cc == 90 and ((kernel_auto and not epilogue_auto) or (not kernel_auto and epilogue_auto)): + return (False, "Kernel and epilogue schedules must either both be auto or neither be auto") + + if not tile_scheduler_default: + cooperative_kernels = [cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative, + cutlass_cppgen.KernelScheduleType.CpAsyncWarpSpecializedCooperative] + if cc == 90 and (tile_scheduler == cutlass_cppgen.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels): + return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule") + return (True, "") + + +def alignment_or_default(alignment_provided: int, default_alignment: int) -> int: + """ + Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks + that `alignment_provided` does not exceed `default_alignment`. 
+ + :param alignment_provided: alignment preference specified. Can be None. + :type alignment_provided: int + :param default_alignment: alignment to use if `alignment_provided` is None + :type default_alignment: int + + :return: alignment to use + :rtype: int + """ + if alignment_provided is not None: + if alignment_provided > default_alignment: + raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.") + return alignment_provided + + return default_alignment + + +def update_alignment(alignment_provided:int, default_alignment: int) -> int: + """ + Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks + that `alignment_provided` does not exceed `default_alignment`. + + :param alignment_provided: alignment preference specified. Can be None. + :type alignment_provided: int + :param default_alignment: alignment to use if `alignment_provided` is None + :type default_alignment: int + + :return: alignment to use + :rtype: int + """ + if alignment_provided is not None: + if alignment_provided > default_alignment: + if alignment_provided % default_alignment == 0: + return default_alignment + raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.") + return alignment_provided + + return default_alignment diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py new file mode 100644 index 0000000000000000000000000000000000000000..c03a834dc47871bebe618752e4775a0a7434ff78 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py @@ -0,0 +1,362 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Utility functions for converting between frontend datatypes and CUTLASS datatypes +""" + +import cutlass_cppgen +from cutlass_library import ( + DataTypeSize, + MathOperation, + MathInstruction +) +from cutlass_cppgen.backend.library import ( + TileDescription, +) + +bfloat16_available = None +cupy_available = None +numpy_available = None +torch_available = None +_library_to_cupy_dict = None +_library_to_numpy_dict = None +_library_to_torch_dict = None +_torch_to_library_dict = None + + +def is_numpy_available(): + global numpy_available, _library_to_numpy_dict + if numpy_available is None: + try: + import numpy as np + + numpy_available = True + _library_to_numpy_dict = { + cutlass_cppgen.DataType.f16: np.float16, + cutlass_cppgen.DataType.f32: np.float32, + cutlass_cppgen.DataType.f64: np.float64, + cutlass_cppgen.DataType.s8: np.int8, + cutlass_cppgen.DataType.s32: np.int32, + } + except ImportError: + numpy_available = False + _library_to_numpy_dict = {} + return numpy_available + + +def is_numpy_tensor(inp) -> bool: + if is_numpy_available(): + import numpy as np + return isinstance(inp, np.ndarray) + return False + + +def numpy_library_type(inp) -> cutlass_cppgen.DataType: + if is_numpy_available(): + import numpy as np + if inp == np.float16: + return cutlass_cppgen.DataType.f16 + elif inp == np.float32: + return cutlass_cppgen.DataType.f32 + elif inp == np.float64: + return cutlass_cppgen.DataType.f64 + elif inp == np.int8: + return cutlass_cppgen.DataType.s8 + elif inp == np.int32: + return cutlass_cppgen.DataType.s32 + return None + + +def numpy_type(inp): + return _library_to_numpy_dict.get(inp, None) + + +def is_cupy_available(): + global cupy_available + if cupy_available is None: + try: + import cupy as cp + + cupy_available = True + _library_to_cupy_dict = { + cutlass_cppgen.DataType.f16: cp.float16, + cutlass_cppgen.DataType.f32: cp.float32, 
+ cutlass_cppgen.DataType.f64: cp.float64, + cutlass_cppgen.DataType.s8: cp.int8, + cutlass_cppgen.DataType.s32: cp.int32, + } + except ImportError: + cupy_available = False + _library_to_cupy_dict = {} + return cupy_available + + +def is_cupy_tensor(inp) -> bool: + if is_cupy_available(): + import cupy as cp + return isinstance(inp, cp.ndarray) + return False + + +def cupy_library_type(inp) -> cutlass_cppgen.DataType: + if is_cupy_available(): + import cupy as cp + if inp == cp.float16: + return cutlass_cppgen.DataType.f16 + elif inp == cp.float32: + return cutlass_cppgen.DataType.f32 + elif inp == cp.float64: + return cutlass_cppgen.DataType.f64 + return None + + +def cupy_type(inp): + return _library_to_cupy_dict.get(inp, None) + + +def is_torch_available(): + global torch_available, _library_to_torch_dict, _torch_to_library_dict + if torch_available is None: + try: + import torch + + torch_available = True + _torch_to_library_dict = { + torch.half: cutlass_cppgen.DataType.f16, + torch.float16: cutlass_cppgen.DataType.f16, + torch.bfloat16: cutlass_cppgen.DataType.bf16, + torch.float: cutlass_cppgen.DataType.f32, + torch.float32: cutlass_cppgen.DataType.f32, + torch.double: cutlass_cppgen.DataType.f64, + torch.float64: cutlass_cppgen.DataType.f64, + torch.int8: cutlass_cppgen.DataType.s8, + torch.int32: cutlass_cppgen.DataType.s32, + torch.uint8: cutlass_cppgen.DataType.u8, + } + + _library_to_torch_dict = { + cutlass_cppgen.DataType.f16: torch.half, + cutlass_cppgen.DataType.f16: torch.float16, + cutlass_cppgen.DataType.bf16: torch.bfloat16, + cutlass_cppgen.DataType.f32: torch.float, + cutlass_cppgen.DataType.f32: torch.float32, + cutlass_cppgen.DataType.f64: torch.double, + cutlass_cppgen.DataType.f64: torch.float64, + cutlass_cppgen.DataType.s8: torch.int8, + cutlass_cppgen.DataType.s32: torch.int32, + cutlass_cppgen.DataType.u8: torch.uint8, + } + + def possibly_add_type(torch_type_name, cutlass_type): + # Only try adding the type if the version of torch 
being used supports it + if hasattr(torch, torch_type_name): + torch_type = getattr(torch, torch_type_name) + _torch_to_library_dict[torch_type] = cutlass_type + _library_to_torch_dict[cutlass_type] = torch_type + + possibly_add_type("float8_e4m3fn", cutlass_cppgen.DataType.e4m3) + possibly_add_type("float8_e5m2", cutlass_cppgen.DataType.e5m2) + + except ImportError: + torch_available = False + _torch_to_library_dict = {} + _library_to_torch_dict = {} + return torch_available + + +def is_torch_tensor(inp) -> bool: + if is_torch_available(): + import torch + return isinstance(inp, torch.Tensor) + return False + + +def torch_library_type(inp) -> cutlass_cppgen.DataType: + return _torch_to_library_dict.get(inp, None) + + +def torch_type(inp): + return _library_to_torch_dict.get(inp, None) + + +def is_bfloat16_available(): + global bfloat16_available + + if bfloat16_available is None: + try: + import bfloat16 + + bfloat16_available = True + except ImportError: + bfloat16_available = False + return bfloat16_available + + +def bfloat16_library_type(inp) -> cutlass_cppgen.DataType: + if is_bfloat16_available(): + import bfloat16 + if inp == bfloat16.bfloat16: + return cutlass_cppgen.DataType.bf16 + + +def bfloat16_type(inp): + if is_bfloat16_available(): + import bfloat16 + if inp == cutlass_cppgen.DataType.bf16: + return bfloat16.bfloat16 + + +def library_type(inp): + if inp in DataTypeSize: + return inp + + for cvt_fn in [ + bfloat16_library_type, + cupy_library_type, + numpy_library_type, + torch_library_type, + ]: + out = cvt_fn(inp) + if out is not None: + return out + + raise Exception(f"No available conversion from type {inp} to a library type.") + + +def _tensor_from_numpy(np_tensor): + dtype = library_type(np_tensor.dtype) + if np_tensor.flags.c_contiguous: + layout = cutlass_cppgen.LayoutType.RowMajor + elif np_tensor.flags.f_contiguous: + layout = cutlass_cppgen.LayoutType.ColumnMajor + return (dtype, layout) + + +def _tensor_from_torch(pt_tensor): + dtype = 
library_type(pt_tensor.dtype) + return (dtype, cutlass_cppgen.LayoutType.RowMajor) + + +def get_datatype_and_layout(tensor): + if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)): + return _tensor_from_numpy(tensor) + elif is_torch_tensor(tensor): + return _tensor_from_torch(tensor) + elif isinstance(tensor, float) or isinstance(tensor, int): + return (cutlass_cppgen.DataType.f32, cutlass_cppgen.LayoutType.RowMajor) + else: + raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.") + + +def get_tensor_shape(tensor, op="GEMM"): + if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)): + return tensor.shape + elif is_torch_tensor(tensor): + size = tensor.size() + if op == "CONV": + # PyTorch Tensors have shape NCHW + return (size[0], size[2], size[3], size[1]) + else: + return tuple(tensor.size()) + elif isinstance(tensor, float) or isinstance(tensor, int): + return (1,) + else: + raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.") + + +_math_operation_value_map = {x.value: x for x in MathOperation} + + +def backend_math_operation(math_op: MathOperation): + if math_op.value not in _math_operation_value_map.keys(): + raise Exception(f"Unable to convert math operation of type {math_op} to backend math operation.") + return _math_operation_value_map[math_op.value] + + +def construct_backend_td(td: cutlass_cppgen.TileDescription, + kernel_schedule: cutlass_cppgen.KernelScheduleType, + epilogue_schedule: cutlass_cppgen.EpilogueScheduleType, + tile_scheduler: cutlass_cppgen.TileSchedulerType) -> TileDescription: + mi = td.math_instruction + backend_mi = MathInstruction( + mi.instruction_shape, + mi.element_a, + mi.element_b, + mi.element_accumulator, + mi.opcode_class, + backend_math_operation(mi.math_operation) + ) + cluster_shape = td.cluster_shape if hasattr(td, "cluster_shape") else [1, 1, 1] + return TileDescription(td.threadblock_shape, td.stages, 
td.warp_count, + backend_mi, cluster_shape, kernel_schedule, epilogue_schedule, tile_scheduler) + + +def td_from_profiler_op(op) -> TileDescription: + """ + Converts the profiler's TileDescription in ``op`` into the backend TileDescription + + :param op: profiler Operation + + :returns: backend TileDescription + :rtype: cutlass_cppgen.backend.TileDescription + """ + kschedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None + eschedule = op.epilogue_schedule if hasattr(op, 'epilogue_schedule') else None + tschedule = op.tile_scheduler if hasattr(op, 'tile_scheduler') else None + return construct_backend_td(op.tile_description, kschedule, eschedule, tschedule) + + +def td_from_profiler_td(td: TileDescription) -> TileDescription: + """ + Converts the profiler's TileDescription into the backend TileDescription + + :param td: profiler TileDescription + :type td: cutlass_cppgen.TileDescription + + :returns: backend TileDescription + :rtype: cutlass_cppgen.backend.TileDescription + """ + return construct_backend_td(td, kernel_schedule=None, epilogue_schedule=None, tile_scheduler=None) + + +def to_camel_case(snake_str): + return "".join(x.capitalize() for x in snake_str.lower().split("_")) + + +def getattr_enum(obj, attr_name): + # The attr_name is under the snake_case + camel_attr = to_camel_case(attr_name) + if hasattr(obj, camel_attr): + return getattr(obj, camel_attr) + else: + raise Exception(f"Invalid option: {attr_name}") diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py new file mode 100644 index 0000000000000000000000000000000000000000..16f6a185040f4c2f6167c6191c9bee766a92b1b9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py @@ -0,0 +1,41 @@ 
+################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# +import importlib +from typing import Any + +def lazy_import(mod_name: str) -> Any: + class Lazy: + def __getattr__(self, name:str) -> Any: + module = importlib.import_module(mod_name) + return getattr(module, name) + + return Lazy() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..f53b1567978d17f2eaec0208d896aafb296f033f --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py @@ -0,0 +1,196 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Profiler based on the cuda events +""" + +import re +import subprocess + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +cudart = lazy_import("cuda.cudart") +import numpy as np + +from cutlass_cppgen import CUTLASS_PATH +from cutlass_cppgen.backend.library import DataTypeSize +from cutlass_cppgen.op.op import OperationBase +from cutlass_cppgen.shape import GemmCoord +from cutlass_cppgen.utils.datatypes import is_numpy_tensor + + +class GpuTimer: + def __init__(self) -> None: + self.events = [ + cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1], + cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1], + ] + + def start(self, stream=None): + if not stream: + stream = cuda.CUstream(0) + + (err,) = cuda.cuEventRecord(self.events[0], stream) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {str(err)}") + + def stop(self, stream=None): + if not stream: + stream = cuda.CUstream(0) + + (err,) = cuda.cuEventRecord(self.events[1], stream) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error 
{str(err)}") + pass + + def stop_and_wait(self, stream=None): + if not stream: + stream = cuda.CUstream(0) + + self.stop(stream) + if stream: + (err,) = cuda.cuStreamSynchronize(stream) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {str(err)}") + else: + (err,) = cudart.cudaDeviceSynchronize() + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {str(err)}") + + def duration(self, iterations=1): + err, duration = cuda.cuEventElapsedTime(self.events[0], self.events[1]) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {str(err)}") + return duration / float(iterations) + + +class CUDAEventProfiler: + def __init__(self, op: OperationBase, warmup_iterations: int=500, iterations: int=500, *args, **kwargs) -> None: + self.arguments = op.run(*args, **kwargs) + self.operation = op.operation + self.warmup_iterations = warmup_iterations + self.iterations = iterations + self.timer = GpuTimer() + + # + # Cutlass Python Interface Profiler + # + + def __call__(self): + for _ in range(self.warmup_iterations): + self.operation.run(self.arguments) + + self.timer.start() + for _ in range(self.iterations): + self.operation.run(self.arguments) + + self.timer.stop_and_wait() + runtime = self.timer.duration(self.iterations) + return runtime + + # + # CUTLASS Profiler + # + + def run_cutlass_profiler(self): + alpha = 1.0 + beta = 1.0 + + profiler_path = CUTLASS_PATH + "/build/tools/profiler/cutlass_profiler" + kernel_name = self.operation.procedural_name() + verification_providers = "device" + provider = "cutlass" + problem_size = self.arguments.problem_size + + if "cutlass3x" in kernel_name: + # cutlass3x generator only have column-major output + layout_name = self.operation.layout_name_3x() + if layout_name[-1] == "t": + new_layout_name = "".join(["n" for l in layout_name if l == "t" or "t"]) + problem_size = GemmCoord(problem_size.n, problem_size.m, problem_size.k) + kernel_name = kernel_name.replace(layout_name, 
new_layout_name) + + batch_count = self.arguments.batch_count + + cmd = f"{profiler_path} --kernels={kernel_name} --verification-providers={verification_providers} " \ + f"--providers={provider} --m={problem_size.m()} --n={problem_size.n()} --k={problem_size.k()} " \ + f"--batch_count={batch_count} --alpha={alpha} --beta={beta} "\ + f"--warmup-iterations={self.warmup_iterations} --profiling-iterations={self.iterations}" + + result = subprocess.getoutput(cmd) + + m = re.search(r"Runtime:\s+(?P\d+.\d+)", result) + runtime = float(m.group("runtime")) + + m = re.search(r"Bytes:\s+(?P\d+)", result) + bytes = int(m.group("bytes")) + + m = re.search(r"FLOPs:\s+(?P\d+)", result) + flops = int(m.group("flops")) + + # check if the problem size matches + assert bytes == self.bytes(problem_size, batch_count, beta) + assert flops == self.flops(problem_size, batch_count, beta) + + return runtime + + def bytes(self, problem_size, batch_count=1, beta=0.0): + m = problem_size.m() + n = problem_size.n() + k = problem_size.k() + + bytes = ( + (DataTypeSize[self.operation.A.element] * m // 8) * k + + (DataTypeSize[self.operation.B.element] * n // 8) * k + + (DataTypeSize[self.operation.C.element] * m // 8) * n + ) + + if beta != 0: + bytes += (DataTypeSize[self.operation.C.element] * m // 8) * n + + bytes *= batch_count + + return bytes + + def flops(self, problem_size, batch_count=1, beta=0.0): + m = problem_size.m() + n = problem_size.n() + k = problem_size.k() + + flops_ = (m * n * k) * 2 * batch_count + + if beta != 0: + flops_ += m * n * batch_count * 2 + + return flops_ + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..534eef47d810eb9f17a9ba6dbbe2e0dff935eb3f --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py @@ -0,0 +1,63 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import os +import sys + +from . 
import conv2d_operation +from . import conv3d_operation +from . import emit_kernel_listing +from . import gemm_operation + +if '-m' not in sys.argv: + # Do not import generator when running python -m cutlass_library.generator to + # avoid double-import warnings + from . import generator + +from . import library +from . import manifest +from . import rank_2k_operation +from . import rank_k_operation +from . import symm_operation +from . import trmm_operation +# Make enum types from library.py accessible via cutlass_library.* +from .library import * + +# Set up `source` to point to the path containing the CUTLASS source. +# Check first if the path contains a `source` subdirectory -- this will +# be the case when the package has been installed via pip. Otherwise, +# default to the root of CUTLASS. +install_source_path = os.path.join(__path__[0], 'source') +if os.path.isdir(install_source_path): + source_path = install_source_path +else: + source_path = os.path.join(__path__[0], '../..') diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..b674463a2c5795be8610883c4dc98a1e7123a01b --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py @@ -0,0 +1,621 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Utilities for emitting Conv2d kernels +""" + +import enum +import logging +import os.path +import shutil +from string import Template + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * + from cutlass_library.conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes +except ImportError: + from library import * + from conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes + +_LOGGER = logging.getLogger(__name__) + +################################################################################################### + +# +class Conv2dOperation: + # + def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \ + stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity1, \ + group_mode = GroupMode.NoneGroup): + + self.operation_kind = OperationKind.Conv2d + self.arch = arch + self.tile_description = tile_description + self.conv_kind = conv_kind + self.A = A + self.B = B + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.iterator_algorithm = iterator_algorithm + self.stride_support = stride_support + self.swizzling_functor = swizzling_functor + self.group_mode = group_mode + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian + ] + return self.tile_description.math_instruction.math_operation in complex_operators + + # + def is_mixed_input(self): + return self.A.element != self.B.element + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + 
return accum + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + intermediate_type = '' + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp: + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.accumulator_type(): + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + else: + inst_shape = '' + + return "%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \ + inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm]) + + # + def extended_name(self): + ''' Append data types if they differ from compute type. ''' + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def layout_name(self): + return "%s" % (ShortLayoutTypeNames[self.A.layout]) + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. 
''' + + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + threadblock = self.tile_description.procedural_name() + + # grouped conv + if self.group_mode != GroupMode.NoneGroup: + group_conv_name = f"{GroupModeNames[self.group_mode]}_" + else: + group_conv_name = "" + + if self.stride_support == StrideSupport.Unity and self.conv_kind == ConvKind.Dgrad: + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_${group_conv_name}align${alignment}" + else: + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${group_conv_name}align${alignment}" + + return SubstituteTemplate( + configuration_name, + { + 'opcode_class': opcode_class_name, + 'extended_name': self.extended_name(), + 'threadblock': threadblock, + 'layout': self.layout_name(), + 'alignment': "%d" % self.A.alignment, + 'group_conv_name': group_conv_name + } + ) + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. 
''' + return self.configuration_name() + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +class EmitConv2dInstance: + def __init__(self): + # Emitter for CUTLASS 3 convolution operations + self.conv3x_emitter = EmitConv3xInstance() + self.template = """ + // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" + using ${operation_name}_base = + typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator}, + ${iterator_algorithm}, + ${stride_support}, + ${align_a}, + ${align_b} + >::Kernel; +""" + self.template_group_conv = """ + // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" + using ${operation_name}_base = + typename cutlass::conv::kernel::DefaultConv2dGroup${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, + 
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator}, + ${group_mode}, + ${iterator_algorithm}, + ${stride_support}, + ${align_a}, + ${align_b} + >::Kernel; +""" + self.template_depthwise_direct_conv = """ + // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" + using ${operation_name}_base = + typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConv${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::conv::TensorNHWCShape<${threadblock_output_shape_n}, ${threadblock_output_shape_p}, ${threadblock_output_shape_q}, ${groups_per_cta}>, + cutlass::MatrixShape<${filter_shape_r}, ${filter_shape_s}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue}, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + >, + + cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< + 1, + ${threadblock_output_shape_n}, + ${threadblock_output_shape_p}, + ${threadblock_output_shape_q}>, + ${stages}, + ${math_operator}, + ${iterator_algorithm}, + ${stride_support}, + cutlass::MatrixShape<${stride_r}, ${stride_s}>, + cutlass::MatrixShape<${dilation_r}, ${dilation_s}> + >::Kernel; +""" + + def arch_number_to_type(self, arch: int): + return f"cutlass::arch::Sm{arch}" + + def 
emit(self, operation): + _LOGGER.debug("*** EmitConv2dInstance::emit") + _LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name()) + + if hasattr(operation, 'is_3x') and operation.is_3x: + _LOGGER.debug("*** CUTLASS 3 operation") + return self.conv3x_emitter.emit(operation) + + _LOGGER.debug("*** CUTLASS 2 operation") + + warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + 'operation_name': operation.procedural_name(), + 'conv_kind': ConvKindTag[operation.conv_kind], + 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm], + 'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(), + 'stride_support': StrideSupportTag[operation.stride_support], + 'math_operator': 'cutlass::arch::OpMultiplyAddComplex' if operation.is_complex() else \ + MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + } + + if operation.group_mode == GroupMode.NoneGroup: + _LOGGER.debug("*** group_mode=NoneGroup") + return SubstituteTemplate(self.template, values) + + elif operation.group_mode == GroupMode.Depthwise: + _LOGGER.debug("*** group_mode=Depthwise") + values['group_mode'] = GroupModeTag[operation.group_mode] + # Setup other template params + values['threadblock_output_shape_n'] = str(operation.tile_description.threadblock_output_shape[0]) + values['threadblock_output_shape_p'] = str(operation.tile_description.threadblock_output_shape[1]) + values['threadblock_output_shape_q'] = str(operation.tile_description.threadblock_output_shape[2]) + + values['groups_per_cta'] = str(operation.tile_description.threadblock_output_shape[3]) + + values['filter_shape_r'] = str(operation.tile_description.filter_shape[0]) + values['filter_shape_s'] = str(operation.tile_description.filter_shape[1]) + + values['stride_r'] = str(operation.tile_description.stride[0]) + values['stride_s'] = str(operation.tile_description.stride[1]) + + values['dilation_r'] = str(operation.tile_description.dilation[0]) + values['dilation_s'] = str(operation.tile_description.dilation[1]) + + return SubstituteTemplate(self.template_depthwise_direct_conv, values) + + else: + 
_LOGGER.debug("*** group_mode=" + GroupModeTag[operation.group_mode]) + values['group_mode'] = GroupModeTag[operation.group_mode] + return SubstituteTemplate(self.template_group_conv, values) + +################################################################################################### +# +# Generator functions for all layouts +# +################################################################################################### + +# +def GenerateConv2dTensorOp(manifest, tile_descriptions, min_cc, align = 128): + _LOGGER.debug("*** GenerateConv2dTensorOp") + + for tile in tile_descriptions: + for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]: + + if conv_kind == ConvKind.Fprop or (tile.math_instruction.element_accumulator in [DataType.f16, DataType.f32]): + + # + output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \ + if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \ + else [tile.math_instruction.element_accumulator,] + + for output_type in output_types: + A = TensorDescription(tile.math_instruction.element_a, LayoutType.TensorNHWC, int(align / DataTypeSize[tile.math_instruction.element_a])) + B = TensorDescription(tile.math_instruction.element_b, LayoutType.TensorNHWC, int(align / DataTypeSize[tile.math_instruction.element_b])) + C = TensorDescription(output_type, LayoutType.TensorNHWC, max(1, int(align / DataTypeSize[output_type]))) + + manifest.append(Conv2dOperation(conv_kind, min_cc, tile, A, B, C, tile.math_instruction.element_accumulator)) + +class EmitConv2dIncludes: + '''Emit includes that are specific to the operation.''' + + def __init__(self): + self.includes = ['conv2d_operation.h'] + self.emitter_3x = EmitConv3xIncludes() + + def operation_is_3x(self, operation) -> bool: + """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)""" + return hasattr(operation, 'is_3x') and operation.is_3x + + def emit(self, operation) -> str: + if 
self.operation_is_3x(operation): + return self.emitter_3x.emit(operation) + + return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \ + "\n\n///////////////////////////////////////////////////////////////////////////////////////////////////" + +################################################################################################### +# +# Emitters functions for all targets +# +################################################################################################### + +class EmitConv2dConfigurationLibrary: + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name) + + self.instance_emitter = EmitConv2dInstance() + self.includes_emitter = EmitConv2dIncludes() + + self.header_template = """ +/* + Generated by conv2d_operation.py - Do not edit. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "library_internal.h" +""" + + self.instance_template = """ +${stub_begin} +${operation_instance} +// Derived class +struct ${operation_name} : + public ${operation_name}_base { }; +${stub_end} +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.configuration_header = """ + +namespace cutlass { +namespace library { + +// Initialize all instances +void initialize_${configuration_name}(Manifest &manifest) { +""" + + self.configuration_instance = """${stub_begin} + using Operation_${operation_name} = cutlass::conv::device::${kernel_name}< + ${operation_name}>; + + manifest.append(new cutlass::library::${operation_wrapper}< + Operation_${operation_name} + >( + "${operation_name}" + )); +${stub_end} +""" + + self.configuration_epilogue = "}\n" + + self.epilogue_template = """ + 
+/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def operation_is_3x(self, operation): + """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)""" + return hasattr(operation, 'is_3x') and operation.is_3x + + def __enter__(self): + """ + Open the configuration_file, and write the "header" C++ code to it. + + The "header" consists of a comment (that this is generated code, + so it should not be edited), and includes that are common + to all kinds of kernels. + """ + _LOGGER.debug('*** EmitConv2dConfigurationLibrary::__enter__') + _LOGGER.debug('*** configuration_path (file to write): ' + + str(self.configuration_path)) + _LOGGER.debug('*** configuration_name: ' + self.configuration_name) + self.configuration_file = open(self.configuration_path, "w") + + self.configuration_file.write(SubstituteTemplate(self.header_template, { + 'configuration_name': self.configuration_name + })) + self.operations = [] + return self + + def emit(self, operation): + """ + Write three pieces of C++ code to the configuration_file + (that was opened by the __enter__ method above): + + 1. the header includes that are specific to the operation + (CUTLASS 2 vs. CUTLASS 3); + + 2. the "operation instance" (a "using" declaration ending in "_base"); and + + 3. the "operation name" (declaration and definition of a derived class + of the above operation instance). + + The "using" declaration turns a C++ class name, possibly namespace-qualified, + possibly also with angle brackets, into a C-style, easily demangled identifier. 
+ """ + _LOGGER.debug('*** EmitConv2dConfigurationLibrary::emit') + _LOGGER.debug('*** operation.procedural_name(): ' + operation.procedural_name()) + self.operations.append(operation) + + self.configuration_file.write(self.includes_emitter.emit(operation)) + + stub_begin = '' + stub_end = '' + # It can be useful to stub (comment) out instantiations for testing. + # In this case, one need only set is_stub to True. + is_stub = False + if is_stub: + stub_begin = "// STUB for now\n#if 0" + stub_end = '#endif // 0' + + self.configuration_file.write(Template(self.instance_template).substitute({ + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'operation_instance': self.instance_emitter.emit(operation), + 'stub_begin': stub_begin, + 'stub_end': stub_end + })) + + def __exit__(self, exception_type, exception_value, traceback): + """ + Write the rest of the C++ code to the configuration_file, and close the file. + + The "rest of the C++ code" has the following components. + + 1. Configuration header: Open the namespace(s), and open the definition + of the "initialize_${configuration_name}" registration function + that registers the operation with the Manifest. + ("Registration" helps turn C++ compile-time polymorphism + (via template parameters) into a run-time choice of parameters.) + + 2. Configuration instance: In the body of the registration function, + make a "using" declaration Operation_${operation_name} for the + operation type (which uses operation_name as its template argument). + Then, tell the manifest about the operation via a "manifest.append" call. + The argument of the call is a new instance of + "SomethingOperation" + (replace Something with a specific name). + + 3. Configuration epilogue: Close the definition of the registration function. + + 4. Epilogue template: Close the namespace(s). 
+ """ + + _LOGGER.debug('*** EmitConv2dConfigurationLibrary::__exit__') + _LOGGER.debug('*** configuration_path (file to write): ' + + str(self.configuration_path)) + _LOGGER.debug('*** configuration_name: ' + self.configuration_name) + + self.configuration_file.write(SubstituteTemplate(self.configuration_header, { + 'configuration_name': self.configuration_name + })) + + for operation in self.operations: + stub_begin = '' + stub_end = '' + # It can be useful to stub (comment) out instantiations for testing. + # In this case, one need only set is_stub to True. + is_stub = False + if is_stub: + stub_begin = "// STUB for now\n#if 0" + stub_end = "#endif // 0" + + if operation.group_mode == GroupMode.Depthwise: + kernel_name = 'DirectConvolution' + operation_wrapper = 'DirectConv2dOperation' + else: + kernel_name = 'ImplicitGemmConvolution' + operation_wrapper = 'Conv2dOperation' + if self.operation_is_3x(operation): + kernel_name = 'ConvUniversalAdapter' + operation_wrapper = 'ConvOperation3x' + + self.configuration_file.write(SubstituteTemplate(self.configuration_instance, { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'kernel_name': kernel_name, + 'operation_wrapper': operation_wrapper, + 'stub_begin': stub_begin, + 'stub_end': stub_end + })) + + self.configuration_file.write(self.configuration_epilogue) + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + + +################################################################################################### +################################################################################################### diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py new file mode 100644 index 
0000000000000000000000000000000000000000..b96b6db74224e52bd90b6e184a62624475385352 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py @@ -0,0 +1,482 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Utilities for emitting Conv3d kernels +""" + +import enum +import logging +import os.path +import shutil +from string import Template + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * + from cutlass_library.conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes +except ImportError: + from library import * + from conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes + +_LOGGER = logging.getLogger(__name__) + +################################################################################################### + +# +class Conv3dOperation: + # + def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \ + stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): + + self.operation_kind = OperationKind.Conv3d + self.arch = arch + self.tile_description = tile_description + self.conv_kind = conv_kind + self.A = A + self.B = B + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.iterator_algorithm = iterator_algorithm + self.stride_support = stride_support + self.swizzling_functor = swizzling_functor + + # + def is_mixed_input(self): + return self.A.element != self.B.element + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. 
''' + + intermediate_type = '' + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp: + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + else: + inst_shape = '' + + return "%s%s%s%s3d_%s" % (ShortDataTypeNames[self.tile_description.math_instruction.element_accumulator], \ + inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm]) + + # + def extended_name(self): + ''' Append data types if they differ from compute type. ''' + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. 
''' + + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + threadblock = "%dx%d_%dx%d" % ( + self.tile_description.threadblock_shape[0], + self.tile_description.threadblock_shape[1], + self.tile_description.threadblock_shape[2], + self.tile_description.stages + ) + + if self.stride_support == StrideSupport.Unity: + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_unity_stride" + else: + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}" + + return SubstituteTemplate( + configuration_name, + { + 'opcode_class': opcode_class_name, + 'extended_name': self.extended_name(), + 'threadblock': threadblock, + } + ) + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + return self.configuration_name() + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +class EmitConv3dInstance: + def __init__(self): + # Emitter for CUTLASS 3 convolution operations + self.conv3x_emitter = EmitConv3xInstance() + self.template = """ + // Conv3d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" + using ${operation_name}_base = + typename cutlass::conv::kernel::DefaultConv3d${conv_kind_name}< + ${element_a}, + cutlass::layout::TensorNDHWC, + ${element_b}, + cutlass::layout::TensorNDHWC, + ${element_c}, + cutlass::layout::TensorNDHWC, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + 
${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, + ${stages}, + cutlass::arch::OpMultiplyAdd, + ${iterator_algorithm}, + ${stride_support} + >::Kernel; +""" + + def emit(self, operation): + _LOGGER.debug("*** EmitConv3dInstance::emit") + _LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name()) + + if hasattr(operation, 'is_3x') and operation.is_3x: + _LOGGER.debug("*** CUTLASS 3 operation") + return self.conv3x_emitter.emit(operation) + + _LOGGER.debug("*** CUTLASS 2 operation") + + warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + 'operation_name': operation.procedural_name(), + 'conv_kind': ConvKindTag[operation.conv_kind], + 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': 
str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm], + 'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(), + 'stride_support': StrideSupportTag[operation.stride_support] + } + + return SubstituteTemplate(self.template, values) + +################################################################################################### +# +# Generator functions for all layouts +# +################################################################################################### + +# +def GenerateConv3dTensorOp(manifest, tile_descriptions, min_cc, align = 128): + + for tile in tile_descriptions: + for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]: + + if conv_kind == ConvKind.Fprop or (tile.math_instruction.element_accumulator in [DataType.f16, DataType.f32]): + + # + output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \ + if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \ + else [tile.math_instruction.element_accumulator,] + + for output_type in output_types: + A = TensorDescription(tile.math_instruction.element_a, LayoutType.TensorNDHWC, int(align / DataTypeSize[tile.math_instruction.element_a])) + B = TensorDescription(tile.math_instruction.element_b, LayoutType.TensorNDHWC, int(align / 
DataTypeSize[tile.math_instruction.element_b])) + C = TensorDescription(output_type, LayoutType.TensorNDHWC, max(1, int(align / DataTypeSize[output_type]))) + + manifest.append(Conv3dOperation(conv_kind, min_cc, tile, A, B, C, tile.math_instruction.element_accumulator)) + +class EmitConv3dIncludes: + '''Emit includes that are specific to the operation.''' + + def __init__(self): + self.includes = ['conv3d_operation.h'] + self.emitter_3x = EmitConv3xIncludes() + + def operation_is_3x(self, operation) -> bool: + """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)""" + return hasattr(operation, 'is_3x') and operation.is_3x + + def emit(self, operation) -> str: + if self.operation_is_3x(operation): + return self.emitter_3x.emit(operation) + + return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \ + "\n\n///////////////////////////////////////////////////////////////////////////////////////////////////" + +################################################################################################### +# +# Emitters functions for all targets +# +################################################################################################### + +class EmitConv3dConfigurationLibrary: + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name) + + self.instance_emitter = EmitConv3dInstance() + self.includes_emitter = EmitConv3dIncludes() + + self.header_template = """ +/* + Generated by conv3d_operation.py - Do not edit. 
+*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "library_internal.h" +""" + + self.instance_template = """ +${stub_begin} +${operation_instance} +// Derived class +struct ${operation_name} : + public ${operation_name}_base { }; +${stub_end} +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.configuration_header = """ + +namespace cutlass { +namespace library { + +// Initialize all instances +void initialize_${configuration_name}(Manifest &manifest) { +""" + + self.configuration_instance = """${stub_begin} + using Operation_${operation_name} = cutlass::conv::device::${kernel_name}< + ${operation_name}>; + + manifest.append(new cutlass::library::${operation_wrapper}< + Operation_${operation_name} + >( + "${operation_name}" + )); +${stub_end} +""" + + self.configuration_epilogue = "}\n" + + self.epilogue_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def operation_is_3x(self, operation): + """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)""" + return hasattr(operation, 'is_3x') and operation.is_3x + + def __enter__(self): + """ + Open the configuration_file, and write the "header" C++ code to it. + + The "header" consists of a comment (that this is generated code, + so it should not be edited), and includes that are common + to both the CUTLASS 2 and the CUTLASS 3 cases. 
+ """ + _LOGGER.debug('*** EmitConv3dConfigurationLibrary::__enter__') + _LOGGER.debug('*** configuration_path (file to write): ' + + str(self.configuration_path)) + _LOGGER.debug('*** configuration_name: ' + self.configuration_name) + self.configuration_file = open(self.configuration_path, "w") + + self.configuration_file.write(SubstituteTemplate(self.header_template, { + 'configuration_name': self.configuration_name + })) + self.operations = [] + return self + + def emit(self, operation): + """ + Write three pieces of C++ code to the configuration_file + (that was opened by the __enter__ method above): + + 1. the header includes that are specific to the operation + (CUTLASS 2 vs. CUTLASS 3); + + 2. the "operation instance" (a "using" declaration ending in "_base"); and + + 3. the "operation name" (declaration and definition of a derived class + of the above operation instance). + + The "using" declaration turns a C++ class name, possibly namespace-qualified, + possibly also with angle brackets, into a C-style, easily demangled identifier. + """ + _LOGGER.debug('*** EmitConv3dConfigurationLibrary::emit') + _LOGGER.debug('*** operation.procedural_name(): ' + operation.procedural_name()) + self.operations.append(operation) + + self.configuration_file.write(self.includes_emitter.emit(operation)) + + stub_begin = '' + stub_end = '' + # It can be useful to stub (comment) out instantiations for testing. + # In this case, one need only set is_stub to True. 
+ is_stub = False + if is_stub: + stub_begin = "// STUB for now\n#if 0" + stub_end = '#endif // 0' + + self.configuration_file.write(Template(self.instance_template).substitute({ + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'operation_instance': self.instance_emitter.emit(operation), + 'stub_begin': stub_begin, + 'stub_end': stub_end + })) + + def __exit__(self, exception_type, exception_value, traceback): + """ + Write the rest of the C++ code to the configuration_file, and close the file. + + The "rest of the C++ code" has the following components. + + 1. Configuration header: Open the namespace(s), and open the definition + of the "initialize_${configuration_name}" registration function + that registers the operation with the Manifest. + ("Registration" helps turn C++ compile-time polymorphism + (via template parameters) into a run-time choice of parameters.) + + 2. Configuration instance: In the body of the registration function, + make a "using" declaration Operation_${operation_name} for the + operation type (which uses operation_name as its template argument). + Then, tell the manifest about the operation via a "manifest.append" call. + The argument of the call is a new instance of + "SomethingOperation" + (replace Something with a specific name). + + 3. Configuration epilogue: Close the definition of the registration function. + + 4. Epilogue template: Close the namespace(s). + """ + + _LOGGER.debug('*** EmitConv3dConfigurationLibrary::__exit__') + _LOGGER.debug('*** configuration_path (file to write): ' + + str(self.configuration_path)) + _LOGGER.debug('*** configuration_name: ' + self.configuration_name) + + self.configuration_file.write(SubstituteTemplate(self.configuration_header, { + 'configuration_name': self.configuration_name + })) + + for operation in self.operations: + stub_begin = '' + stub_end = '' + # It can be useful to stub (comment) out instantiations for testing. 
+ # In this case, one need only set is_stub to True. + is_stub = False + if is_stub: + stub_begin = "// STUB for now\n#if 0" + stub_end = "#endif // 0" + + kernel_name = 'ImplicitGemmConvolution' + operation_wrapper = 'Conv3dOperation' + if self.operation_is_3x(operation): + kernel_name = 'ConvUniversalAdapter' + operation_wrapper = 'ConvOperation3x' + + self.configuration_file.write(SubstituteTemplate(self.configuration_instance, { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'kernel_name': kernel_name, + 'operation_wrapper': operation_wrapper, + 'stub_begin': stub_begin, + 'stub_end': stub_end + })) + + self.configuration_file.write(self.configuration_epilogue) + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + + +################################################################################################### +################################################################################################### diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py new file mode 100644 index 0000000000000000000000000000000000000000..33d6da1a4675c0bbd07315717a7f5ba0ba0dc10c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py @@ -0,0 +1,250 @@ +################################################################################################# +# +# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Utilities for emitting CUTLASS >= 3 convolution kernels +""" + +import enum +import os.path +import shutil +import logging +from string import Template + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + +_LOGGER = logging.getLogger(__name__) + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +class EmitConv3xInstance: + def __init__(self): + _LOGGER.debug("*** EmitConv3xInstance::__init__") + + # Define epilogue type first, so that the mainloop type + # can use it with StageCountAutoCarveout. 
+ self.template = """ + +// CUTLASS >= 3 convolution ${conv_kind_name} kernel instance "${operation_name}" +using ${operation_name}_epilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ${arch}, + ${opcode_class_epi}, + ${mma_tile_shape}, // mma tile shape + ${cluster_shape}, // cluster shape + ${epi_tile_mn}, + ${element_accumulator}, + ${element_compute}, + ${element_c}, ${layout_c}, 128 / cute::sizeof_bits_v<${element_c}>, + ${element_d}, ${layout_d}, 128 / cute::sizeof_bits_v<${element_d}>, + ${epilogue_schedule} + // , class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination + >::CollectiveOp; + +using ${operation_name}_mainloop = + typename cutlass::conv::collective::CollectiveBuilder< + ${arch}, + ${opcode_class_main}, + ${conv_kind}, // kFprop, kDgrad, or kWgrad + ${element_a}, ${layout_a}, 128 / cute::sizeof_bits_v<${element_a}>, + ${element_b}, ${layout_b}, 128 / cute::sizeof_bits_v<${element_b}>, + ${element_accumulator}, + ${mma_tile_shape}, // mma tile shape + ${cluster_shape}, // cluster shape + ${stages}, + ${kernel_schedule} + >::CollectiveOp; + +using ${operation_name}_problem_shape = cutlass::conv::ConvProblemShape<${conv_kind}, ${operation_name}_mainloop::NumSpatialDimensions>; + +// Unit tests call this "ConvKernel". +// Conv operator ${operation_name} +using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal< + ${operation_name}_problem_shape, + ${operation_name}_mainloop, + ${operation_name}_epilogue, + ${tile_scheduler} + >; +""" + + def arch_number_to_type(self, arch: int) -> str: + return f"cutlass::arch::Sm{arch}" + + def mma_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str: + mma_m = cta_m + mma_n = cta_n + mma_k = cta_k + + if operation.arch >= 100: + # MmaTileShape (mma_m, mma_n, mma_k) is passed to kernel mainloop where + # mma_m = cta_m for 1sm version and mma_m = cta_m * 2 for 2sm version. 
+ # If schedule is auto and cluster size is static and cta_m % 64 == 0 and cluster_m % 2 == 0, 2sm kernel version is allocated, + # otherwise 1sm kernel is allocated. + cta_m_per_mma_instruction = 1 + if "2sm" in operation.procedural_name() : + cta_m_per_mma_instruction = 2 + elif "1sm" in operation.procedural_name() : + cta_m_per_mma_instruction = 1 + elif operation.tile_description.cluster_shape[0] > 0 and operation.tile_description.cluster_shape[0] % 2 == 0 and cta_m % 64 == 0 : + cta_m_per_mma_instruction = 2 + mma_m = cta_m * cta_m_per_mma_instruction + + # For all three kinds of convolutions, the tile shape's K mode + # differs from GEMM in that needs to be wrapped in a Shape. + # For Wgrad convolutions specifically, + # the N tile shape also needs to be wrapped in a Shape. + m_template = 'cute::_${mma_m}' + if operation.conv_kind == ConvKind.Wgrad: + n_template = 'cute::Shape' + else: + n_template = 'cute::_${mma_n}' + k_template = 'cute::Shape' + + mma_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>' + values = { + 'mma_m': mma_m, + 'mma_n': mma_n, + 'mma_k': mma_k + } + return Template(mma_tile_shape_template).substitute(values) + + def cluster_shape(self, operation) -> str: + m_template = 'cute::_${cluster_shape_m}' if operation.tile_description.cluster_shape[0] > 0 else 'int(0)' + n_template = 'cute::_${cluster_shape_n}' if operation.tile_description.cluster_shape[1] > 0 else 'int(0)' + k_template = 'cute::_${cluster_shape_k}' if operation.tile_description.cluster_shape[2] > 0 else 'int(0)' + cluster_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>' + values = { + 'cluster_shape_m': operation.tile_description.cluster_shape[0], + 'cluster_shape_n': operation.tile_description.cluster_shape[1], + 'cluster_shape_k': operation.tile_description.cluster_shape[2], + } + return Template(cluster_shape_template).substitute(values) + + def stage_count(self, operation) -> str: + # stages == 0 tells builder to pick 
the number of stages automatically + namespace_prefix = 'cutlass::conv::collective::' + if operation.tile_description.stages > 0: + return f"{namespace_prefix}StageCount<{str(operation.tile_description.stages)}>" + else: + return f"{namespace_prefix}StageCountAutoCarveout" + + def emit(self, operation) -> str: + _LOGGER.debug("*** EmitConv3xInstance::emit") + _LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name()) + + # Identify the operation as CUTLASS 3 by its is_3x field + if (not hasattr(operation, 'is_3x')) or (not operation.is_3x): + raise RuntimeError("operation must be a CUTLASS 3 operation") + + epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto" + opcode_class_main = OpcodeClassTag[operation.tile_description.math_instruction.opcode_class] + opcode_class_epi = opcode_class_main + + tile_shape = operation.tile_description.tile_shape + cluster_m = operation.tile_description.cluster_shape[0] + cluster_n = operation.tile_description.cluster_shape[1] + + cta_m, cta_n, cta_k = tile_shape + # account for static/dynamic cluster shapes + if operation.arch >= 100: + cta_m = cta_m // cluster_m if cluster_m > 0 else cta_m + cta_n = cta_n // cluster_n if cluster_n > 0 else cta_n + + warp_count = operation.tile_description.warp_count + epilogue_schedule = EpilogueScheduleTag[operation.epilogue_schedule] + + # KernelScheduleTag and TileSchedulerTag both hard-code the + # namespace qualification of KernelScheduleAuto as + # "cutlass::gemm::collective::" (unless the tag is 'void'). + # + # For TileSchedulerTag, this namespace is fine, since CUTLASS 3 + # convolutions use the same tile schedulers (from the same + # cutlass::gemm::collective namespace) as GEMMs. 
+ kernel_schedule = KernelScheduleTag[operation.kernel_schedule].replace('gemm::', 'conv::') + tile_scheduler = TileSchedulerTag[operation.tile_scheduler] + opcode_class = OpcodeClassTag[operation.tile_description.math_instruction.opcode_class] + + values = { + 'operation_name': operation.procedural_name(), + 'conv_kind': ConvKindTag[operation.conv_kind], + 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'align_a': int(operation.A.alignment), + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'align_b': int(operation.B.alignment), + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'align_c': int(operation.C.alignment), + 'element_d': DataTypeTag[operation.D.element], + 'layout_d': LayoutTag[operation.D.layout], + 'align_d': int(operation.D.alignment), + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': opcode_class, + 'arch': self.arch_number_to_type(operation.arch), + 'mma_tile_shape': self.mma_tile_shape(operation, cta_m, cta_n, cta_k), + 'cluster_shape': self.cluster_shape(operation), + 'opcode_class_epi': opcode_class_epi, + 'opcode_class_main': opcode_class_main, + 'epi_tile_mn': epi_tile_mn, + 'stages': self.stage_count(operation), + 'kernel_schedule': kernel_schedule, + 'epilogue_schedule': epilogue_schedule, + 'tile_scheduler': tile_scheduler, + 'element_compute': DataTypeTag[operation.element_compute] + } + return Template(self.template).substitute(values) + +class EmitConv3xIncludes: + def __init__(self): + _LOGGER.debug("*** EmitConv3xIncludes::__init__") + self.includes = ['conv_operation_3x.hpp', + 'cutlass/conv/device/conv_universal_adapter.hpp', + 'cutlass/conv/kernel/conv_universal.hpp', + 'cutlass/conv/collective/collective_builder.hpp', + 'cutlass/epilogue/collective/collective_builder.hpp'] + + def emit(self, 
operation) -> str: + _LOGGER.debug("*** EmitConv3xIncludes::emit") + return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \ + "\n\n///////////////////////////////////////////////////////////////////////////////////////////////////" diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/emit_kernel_listing.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/emit_kernel_listing.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe52eb587ab1b5e4595739be5790151b00e0a70 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/emit_kernel_listing.py @@ -0,0 +1,868 @@ +################################################################################################# +# +# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +# +# +# \brief Generates the CUTLASS kernel listing with kernel filtering +# + +# + +############################################################################### +# Example usage: +# generator.py --operations all --generator-target kernel_listing \ +# --architectures "70;75;80" --kernels "*" --disable-cutlass-package-imports +############################################################################### + +import collections +import csv +import json +import math +import os + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + +audit_csv_fields = [ + "KernelType", "KernelName", "Type_A", "Type_B", "Type_C", "Type_Acc", "Type_EpilogueScale", "Type_D", "Type_SFA", "Type_SFD", + "Layout_A", "Layout_B", "Layout_C", "Layout_D", + "Alignment_A", "Alignment_B", "Alignment_C", "Alignment_D", + "1SM/2SM", + "StreamK Enabled", "Support Runtime_Cluster_Shape", "Support Runtime_Input_Types", + "Test Counts" +] + +audit_csv_runtime_fields = [ + "KerneIndex", "KernelName", + "Inst_M", "Inst_N", "Inst_K", "Tile_M", "Tile_N", "Tile_K", + "Cluster_M", "Cluster_N", "Cluster_K", "Preferred_Cluster_M", "Preferred_Cluster_N", "Preferred_Cluster_K", 
"Fallback_Cluster_M", "Fallback_Cluster_N", "Fallback_Cluster_K", + "M", "N", "K", "L", "Alpha_val", "Beta_val", + "Runtime_Input_Types Enabled", "Runtime_Cluster_Shape Enabled" +] + +def hash_cutlass_string(input_string): + mma_cluster_shape_pattern = r"_\d+x\d+x\d+" # Matches MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1') + + # Remove MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1') + output = re.sub(mma_cluster_shape_pattern, "", input_string) + + return output + +def transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b): + # Define a dictionary mapping the detected types to runtime values + datatype_map = { + 'f4_f4': runtime_datatype_a + '_' + runtime_datatype_b, + 'f4_f6': runtime_datatype_a + '_' + runtime_datatype_b, + 'f4_f8': runtime_datatype_a + '_' + runtime_datatype_b, + 'f6_f4': runtime_datatype_a + '_' + runtime_datatype_b, + 'f6_f6': runtime_datatype_a + '_' + runtime_datatype_b, + 'f6_f8': runtime_datatype_a + '_' + runtime_datatype_b, + 'f8_f4': runtime_datatype_a + '_' + runtime_datatype_b, + 'f8_f6': runtime_datatype_a + '_' + runtime_datatype_b, + 'f8_f8': runtime_datatype_a + '_' + runtime_datatype_b, + 'ue8m0xf4_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue4m3xf4_ue4m3xf4': 'ue4m3x' + runtime_datatype_a + '_ue4m3x' + runtime_datatype_b, + 'ue8m0xf4_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf4_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf6_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf6_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf8_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf8_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf8_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + } + + # Regular expression 
to detect all the keys in datatype_map + pattern = re.compile(r'(' + '|'.join(map(re.escape, datatype_map.keys())) + r')') + + # Replace detected patterns using the dictionary + updated_kernel_name = pattern.sub(lambda match: datatype_map[match.group(0)], hashed_kernel_name) + + return updated_kernel_name + +# This helper function reports foundational kernel features: datatypes, layouts, alignment and stream-k. +def get_kernel_features(operation, kernel_name, + dynamic_datatype, runtime_input_datatype): + numcta_inst = "2sm" if "2sm" in kernel_name else "1sm" + math_inst = operation.tile_description.math_instruction + + if dynamic_datatype: + dtype_name_A = runtime_input_datatype[0] + dtype_name_B = runtime_input_datatype[1] + else: + dtype_name_A = DataTypeNames[operation.A.element] + dtype_name_B = DataTypeNames[operation.B.element] + + layout_name_A = ShortLayoutTypeNames[operation.A.layout] + layout_name_B = ShortLayoutTypeNames[operation.B.layout] + layout_name_C = ShortLayoutTypeNames[operation.C.layout] + layout_name_D = ShortLayoutTypeNames[operation.D.layout] + + scale_factor_D_type = operation.ScaleFactorD.element if hasattr(operation, "ScaleFactorD") else DataType.void + scale_factor_A_type = getattr(operation, "ScaleFactorA", DataType.void) + audit_vals = [ + "BlockScaledGEMM" if math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp else "GEMM", + kernel_name, + dtype_name_A, + dtype_name_B, + DataTypeNames[operation.C.element], + DataTypeNames[operation.tile_description.math_instruction.element_accumulator], + DataTypeNames[operation.element_epilogue], + DataTypeNames[operation.D.element], + DataTypeNames[scale_factor_D_type], + DataTypeNames[scale_factor_A_type], + layout_name_A, + layout_name_B, + layout_name_C, + layout_name_D, + str(operation.A.alignment), + str(operation.B.alignment), + str(operation.C.alignment), + str(operation.D.alignment), + numcta_inst, + "Y" if 'stream_k' in kernel_name else "N", + ] + return audit_vals + +# This helper 
function reports other performance-related kernel parameters and those can be specified at runtime: cluster_shape, instruction shap, m/n/k and alpha/beta. +def get_kernel_params(operation, kernel_name, cluster_shape, fallback_cluster_shape, problem_shape, alpha, beta, dynamic_datatype, dynamic_cluster): + math_inst = operation.tile_description.math_instruction + audit_vals = [ + str(math_inst.instruction_shape[0]), + str(math_inst.instruction_shape[1]), + str(math_inst.instruction_shape[2]), + str(operation.tile_description.threadblock_shape[0]), + str(operation.tile_description.threadblock_shape[1]), + str(operation.tile_description.threadblock_shape[2]), + str(operation.tile_description.cluster_shape[0]), + str(operation.tile_description.cluster_shape[1]), + str(operation.tile_description.cluster_shape[2]), + str(cluster_shape[0]), + str(cluster_shape[1]), + str(cluster_shape[2]), + str(fallback_cluster_shape[0]), + str(fallback_cluster_shape[1]), + str(fallback_cluster_shape[2]), + str(problem_shape[0]), + str(problem_shape[1]), + str(problem_shape[2]), + str(problem_shape[3]), + str(alpha), + str(beta), + "Y" if dynamic_datatype else "N", + "Y" if dynamic_cluster else "N", + ] + return audit_vals + + +def _getSubOperationType(kernel): + + if kernel.operation_kind == OperationKind.Gemm: + return GemmKindNames[kernel.gemm_kind] + elif kernel.operation_kind == OperationKind.Conv2d: + return "conv_" + ConvKindNames[kernel.conv_kind] + elif kernel.operation_kind == OperationKind.Syrk: + return "syrk_" + SyrkKindNames[kernel.syrk_kind] + elif kernel.operation_kind == OperationKind.Trmm: + return "trmm_" + TrmmKindNames[kernel.trmm_kind] + elif kernel.operation_kind == OperationKind.Symm: + return "symm_" + SymmKindNames[kernel.symm_kind] + else: + raise Exception("Unsupported kernel type") + +def _get_inst_shape(math_instruction): + return "".join(str(x) for x in math_instruction.instruction_shape) + +def _is_simt_inst(math_instruction): + return 
_get_inst_shape(math_instruction) in ["111","114"] + +def _getInstType(input_precision, accumulate_precision, math_instruction): + + # inst_shape + inst_shape = _get_inst_shape(math_instruction) + + # input precision + if input_precision == "fp32" and inst_shape != "111": + inp = "tf32" + else: + inp = input_precision + + # Handle SIMT op types first + if _is_simt_inst(math_instruction): + + simt_input_precision_to_inst = { + "fp32": "FFMA", + "fp64": "DFMA", + "fp16": "HFMA", + "int8": "IDP4A", + } + inst = simt_input_precision_to_inst[input_precision] + + else: # Tensor op instructions + + if accumulate_precision == "cf64": + fp64_acc_map = { + MathOperation.multiply_add_complex_gaussian : "gz", + MathOperation.multiply_add_complex : "z", + } + acc = fp64_acc_map[math_instruction.math_operation] + else: + tensor_op_acc_map = { + "fp32" : "s", + "cf32" : "s", + "fp16" : "h", + "int32": "i", + "fp64" : "d", + } + acc = tensor_op_acc_map[accumulate_precision] + + inst = "{}{}{}".format(acc, inst_shape, inp) + + return inst +# TODO: Computes FLOps/Bytes for GEMM - revisit for conv +def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups=1): + assert not (batch_count > 1 and num_groups > 1) + + # TODO: adjust for sparsity + gmem_bytes = ( + (DataTypeSize[operation.A.element] * m // 8) * k + + (DataTypeSize[operation.B.element] * n // 8) * k + + (DataTypeSize[operation.C.element] * m // 8) * n + ) + + # TODO: complex-valued support + flops = 2 * (m * n * k) + + if bool(beta): + gmem_bytes += (DataTypeSize[operation.C.element] * m // 8) * n + flops += 2 * m * n + + multiplier = max(batch_count, num_groups) + gmem_bytes *= multiplier + flops *= multiplier + + return flops / gmem_bytes + +def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode + ): + # For functional testing, we prefer to run reference computing on device if any + reference_device_archs = ["100a", "103a"] + run_reference_on_device = True if arch in 
reference_device_archs and mode in ["functional_L0", "functional_L1"] else False + profiler_flags_for_verification = "device" if run_reference_on_device else "host" + + # beta values for L0 and L1 + # TODO: randomize beta values for wider coverage + beta_values = [0.5] + + is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "103a", "110a", "110f", "120a", "120f", "121a", "121f"]) + + is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch + + if (mode == "functional_L0") and is_supported_arch: + problem_waves = [0.5, 1.25, 2.5] + + # + # Dense Gemm + # + + sm100_mma_data_type_general = [ + 'gemm_f16_f16_f16_f16_f16', + 'gemm_f16_f16_f16_void_f16', + #'gemm_f16_f16_f32_f16_f16', + 'tf32gemm_f32_f32_f32_f32_f32', + 'bf16gemm_f32_f32_f32_f32_f32', + ] + + exclude_archs = arch not in ("103a") + if exclude_archs: + sm100_mma_data_type_general.append('gemm_s8_s8_s32_s8_s8') + + sm100_mma_data_type_runtime_dtype = [ + 'gemm.*f4_f4_f32_f32_f32', + 'gemm.*f6_f6_f32_f32_f32', + 'gemm.*f8_f8_f32_f32_f32', + ] + + sm100_mma_cluster_size = [ + '8x1x1', + '4x4x1', '2x1x1', + '0x0x1' # dynamic cluster + ] + + # Restrict to two layouts to reduce L0 build and test time. 
+ sm100_mma_layouts = [ + 'tnt', + 'ntn' + ] + + # regex list must be in kernel procedural name order + sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*" + sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*" + + sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*" + sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*" + + # + # Block Scale Gemm + # + + block_scaled_data_type = [ + # runtime datatypes + 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2', + 'gemm.*ue4m3xf4_ue4m3xf4_f32_f16_e5m2', + 'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2', + #'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1', + 'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2', + ] + + block_scaled_tile_k = ['x128_', 'x256_'] + + sm103_block_scaled_data_type = [ + 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2', + 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1', + ] + + sm103_block_scaled_tile_k = ['x768_'] + + block_scaled_cluster_size = [ + '4x4x1', '2x1x1', + '0x0x1' # dynamic cluster + ] + + block_scaled_layouts = ['tnt'] + # regex list must be in kernel procedural name order + block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" + block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + 
").*2sm.*" + + sm103_block_scaled_prefetch_policy = ['tmapf'] + sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, sm103_block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*(" + "|".join(sm103_block_scaled_prefetch_policy) + ").*" + sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, sm103_block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*(" + "|".join(sm103_block_scaled_prefetch_policy) + ").*" + + if arch in ["100a", "100f"]: + kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ + f"({sm100_mma_filter_regex_2sm})|" \ + f"({sm100_mma_filter_regex_1sm_runtime})|" \ + f"({sm100_mma_filter_regex_2sm_runtime})|" \ + f"({block_scaled_filter_regex_1sm})|" \ + f"({block_scaled_filter_regex_2sm})" + elif arch in ["101a", "101f", "110a", "110f"]: + kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ + f"({sm100_mma_filter_regex_2sm})|" \ + f"({sm100_mma_filter_regex_1sm_runtime})|" \ + f"({sm100_mma_filter_regex_2sm_runtime})|" \ + f"({block_scaled_filter_regex_1sm})|" \ + f"({block_scaled_filter_regex_2sm})" + elif arch in ["103a"]: + kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ + f"({sm100_mma_filter_regex_2sm})|" \ + f"({sm100_mma_filter_regex_1sm_runtime})|" \ + f"({sm100_mma_filter_regex_2sm_runtime})|" \ + f"({block_scaled_filter_regex_1sm})|" \ + f"({block_scaled_filter_regex_2sm})|" \ + f"({sm103_block_scaled_filter_regex_1sm})|" \ + f"({sm103_block_scaled_filter_regex_2sm})" + elif arch in ["120a", "120f", "121a", "121f"]: + + # blockscaled sm120_mma kernels + blockscaled_sm120_mma_kernel_cta_tiles = [ + [ '128x128' ] + ] + + # Restrict to two layouts to reduce L0 build and test time. 
+ blockscaled_sm120_mma_layouts = [ 'tn' ] + filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*" + + problem_waves = [0.5, 1.25, 2.5] + + kernel_filter = f"({filter_regex_blockscaled_sm120_mma})" + else: + error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm110a, sm110f, sm103a, sm120a, sm120f, sm121a, sm121f" + raise Exception(error_message) + + elif mode == "functional_L1": + sm100_mma_cluster_size = [ + '0x0x1' # dynamic cluster + ] + # Restrict to two layouts to reduce L1 build and test time. + sm100_mma_layouts = ['tnt', 'ntn'] + sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*" + sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*" + block_scaled_data_type = [ + 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2', + 'ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2', + 'ue8m0xmx8s26_ue8m0xmx8s26_f32_f16_e5m2', + 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1', + 'ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2', + ] + + sm103_block_scaled_data_type = [ + 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2', + 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1', + ] + + block_scaled_cluster_size = ['0x0x1'] + block_scaled_layouts = ['tnt'] + + # regex list must be in kernel procedural name order + block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" + block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + + sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + 
").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" + sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + + filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \ + f"({sm100_mma_filter_regex_2sm})|" \ + f"({block_scaled_filter_regex_1sm})|" \ + f"({block_scaled_filter_regex_2sm})" \ + f"({sm103_block_scaled_filter_regex_1sm})|" \ + f"({sm103_block_scaled_filter_regex_2sm})" + # CTA tiles for sm120 MMA - only run one tile size to reduce build/test times + sm120_mma_kernel_cta_tiles = [ + # h1688, s1688, i16832, i8816 + [ '256x128' ], + # d884, c1688, + [ '128x128' ], + # c1688, z884 + [ '128x64' ], + # gz884 + [ '64x64' ] + ] + + # sm120 MMA instruction shapes, planar complex type excluded as they are not required + sm120_mma_instruction_shapes = [ + [ 'h1688gemm_(?!planar_complex)', + 's1688gemm_f16', + 's1688gemm_bf16', + 's1688gemm_tf32', + 'i16832gemm', + 'i8816gemm' ], + [ 'd884gemm', 'c1688tf32gemm' ] , + [ 'c1688gemm', + 'z884gemm' ], + [ 'gz884gemm'] + ] + + # It's not pretty, but not sure why different instructions support different tile sizes. 
+ filter_regex_sm120_mma_0 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[0], sm120_mma_kernel_cta_tiles[0]]]) + ").*" + filter_regex_sm120_mma_1 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[1], sm120_mma_kernel_cta_tiles[1]]]) + ").*" + filter_regex_sm120_mma_2 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[2], sm120_mma_kernel_cta_tiles[2]]]) + ").*" + filter_regex_sm120_mma_3 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[3], sm120_mma_kernel_cta_tiles[3]]]) + ").*" + + filter_regex_sm120_mma = f"({filter_regex_sm120_mma_0})|({filter_regex_sm120_mma_1})|({filter_regex_sm120_mma_2})|({filter_regex_sm120_mma_3})" + + problem_waves = [0.5, 1.25, 2.5] + + if arch in ["120a", "120f", "121a", "121f"]: + kernel_filter = f"({filter_regex_sm120_mma})" + else: + kernel_filter = f"({filter_regex_sm100_mma})" + else: + raise ValueError() + + outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv") + + audit_file_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_SM{arch}_cutlass3x_gemm.csv") + + audit_file_params_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_params_SM{arch}_cutlass3x_gemm.csv") + + kernel_filter_re = re.compile(kernel_filter) + testcase_counter = 0 + kernels_emitted = 0 + kernels_total = 0 + + perf_json_list = [] + kernel_name_set = set() + + testlist_csv_fields = ["testcase", "metadata"] + testlist_csv_rows = [] + auditlist_csv_map = {} + auditlist_csv_params_map = {} + + kernel_features = {} + + for cc in manifest.operations[OperationKind.Gemm].keys(): + for kernel_name, operation_l in manifest.operations[OperationKind.Gemm][cc].items(): + assert(len(operation_l) == 1) + kernels_total += 1 + if len(kernel_filter_re.findall(kernel_name)) == 0: + continue + # Only test f16 I/O void C kernels in void C kernel set + # Exception: Use 
void C kernels for more accurate perf testing + if '_void_' in kernel_name and 'perf_' not in mode: + if 'f16_f16_f16_void_f16' not in kernel_name : + continue + + kernels_emitted += 1 + kernel_name_set.add(kernel_name) + hashed_kernel_name = hash_cutlass_string(kernel_name) + operation = operation_l[0] + + dynamic_cluster = (operation.tile_description.cluster_shape[0] == 0 + or operation.tile_description.cluster_shape[1] == 0) + + dynamic_datatype = "f8" in kernel_name or "f6" in kernel_name or "f4" in kernel_name + + runtime_input_datatypes = [None] + + if dynamic_datatype: + if "f4_f4" in kernel_name: + runtime_input_datatypes = [['e2m1','e2m1']] + elif "f4_f6" in kernel_name: + runtime_input_datatypes = [['e2m1','e3m2']] + elif "f4_f8" in kernel_name: + runtime_input_datatypes = [['e2m1','e4m3']] + + elif "f6_f4" in kernel_name: + runtime_input_datatypes = [['e3m2','e2m1']] + elif "f6_f6" in kernel_name: + runtime_input_datatypes = [['e3m2','e3m2']] + elif "f6_f8" in kernel_name: + runtime_input_datatypes = [['e3m2','e4m3']] + + elif "f8_f4" in kernel_name: + runtime_input_datatypes = [['e4m3','e2m1']] + elif "f8_f6" in kernel_name: + runtime_input_datatypes = [['e4m3','e3m2']] + elif "f8_f8" in kernel_name: + runtime_input_datatypes = [ + # mask out those not covered in statically encoded test cases + # ['e5m2','e4m3'], + # ['e4m3','e5m2'], + ['e4m3','e4m3'] + ] + + # block scaled kernels + elif "ue8m0xf4_ue8m0xf4" in kernel_name: + runtime_input_datatypes = [['e2m1','e2m1']] + elif "ue4m3xf4_ue4m3xf4" in kernel_name: + runtime_input_datatypes = [['e2m1','e2m1']] + elif "ue8m0xf4_ue8m0xf6" in kernel_name: + runtime_input_datatypes = [['e2m1','e2m3']] + elif "ue8m0xf4_ue8m0xf8" in kernel_name: + runtime_input_datatypes = [['e2m1','e4m3']] + + elif "ue8m0xf6_ue8m0xf4" in kernel_name: + runtime_input_datatypes = [['e2m3','e2m1']] + elif "ue8m0xf6_ue8m0xf6" in kernel_name: + runtime_input_datatypes = [['e2m3','e2m3']] + elif "ue8m0xf8_ue8m0xf4" in kernel_name: + 
runtime_input_datatypes = [['e4m3','e2m1']] + + elif "ue8m0xf8_ue8m0xf4" in kernel_name: + runtime_input_datatypes = [['e4m3','e2m1']] + elif "ue8m0xf8_ue8m0xf6" in kernel_name: + runtime_input_datatypes = [['e4m3','e2m3']] + elif "ue8m0xf8_ue8m0xf8" in kernel_name: + runtime_input_datatypes = [['e4m3','e4m3']] + + if "bstensorop" in kernel_name or is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind): + profiler_flags_for_verification = "host" + + # reduce L1 test runtime if reference kernel is not running on device. + if mode == "functional_L1" and profiler_flags_for_verification == "host" : + problem_waves = [0.5, 2.5] + + + if dynamic_cluster: + if mode == "functional_L0": + runtime_cluster_shapes = [[1,1,1], [2,2,1]] + else: + runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1], [2,4,1], [4,2,1], [4,4,1]] + # reduce L1 test runtime if reference kernel is not running on device. + if profiler_flags_for_verification == "host": + runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1]] + cta_tile_shape_m, cta_tile_shape_n, cta_tile_shape_k = operation.tile_description.threadblock_shape + else: + runtime_cluster_shapes = [operation.tile_description.cluster_shape] + cta_tile_shape_m = int(operation.tile_description.threadblock_shape[0] / operation.tile_description.cluster_shape[0]) + cta_tile_shape_n = int(operation.tile_description.threadblock_shape[1] / operation.tile_description.cluster_shape[1]) + cta_tile_shape_k = int(operation.tile_description.threadblock_shape[2] / operation.tile_description.cluster_shape[2]) + + alignment_a = operation.A.alignment + alignment_b = operation.B.alignment + alignment_c = operation.C.alignment + alignment_ab_max = max(alignment_a, alignment_b) + + layout3x = operation.layout_name_3x() + data_types = operation.datatype_name_3x() + + ctas_per_mma_instruction = 1 + if '_2sm' in kernel_name: + ctas_per_mma_instruction = 2 + valid_cluster_shapes = [] + + # Remove any 
cluster shapes that have cluster_m that is not divisible by 2 + for cs in runtime_cluster_shapes: + if cs[0] % 2 == 0: + valid_cluster_shapes.append(cs) + runtime_cluster_shapes = valid_cluster_shapes + + kernel_problem_waves = problem_waves + if mode == "functional_L0" or mode == "functional_L1": + # for functional testing, we want to perturb just a little from even shapes + # large K = 8 is chosen such that some kernels will warp around their smem buffers, and some will not + # -16 ensures that we are TMA aligned even for FP8/Int8 + min_k = alignment_ab_max if cta_tile_shape_k == alignment_ab_max else cta_tile_shape_k - alignment_ab_max + max_k = (cta_tile_shape_k*8) - alignment_ab_max + problem_shapes_k = [min_k, max_k] + sm_count = 16 + swizzle_sizes = [0] + # Larger k and less than half wave trigger streamk +separate reduction case to be generated + if 'stream_k' in kernel_name: + problem_shapes_k = [max_k, cta_tile_shape_k*32] + kernel_problem_waves = [0.125, 1.25, 2.5] + else: + raise ValueError + + if "void" in kernel_name: + beta_values = [0] + + alignment_shift_m = max(alignment_c, alignment_a) + alignment_shift_n = max(alignment_c, alignment_b) + + is_first_line = True + for index_waves, waves in enumerate(kernel_problem_waves): + for index_k, k in enumerate(problem_shapes_k): + for beta in beta_values: + for cluster_shape in runtime_cluster_shapes: + for runtime_input_datatype in runtime_input_datatypes: + for swizzle_size in swizzle_sizes: + grid_size = waves * sm_count + cluster_shape_m, cluster_shape_n, cluster_shape_k = tuple(cluster_shape) + if cluster_shape_m >= cluster_shape_n: + grid_m = cluster_shape_m + grid_n = grid_size / grid_m + grid_n = max( int((grid_n + cluster_shape_n - 1) / cluster_shape_n) * cluster_shape_n, 1) + else: + grid_n = cluster_shape_n + grid_m = grid_size / grid_n + grid_m = max( int((grid_m + cluster_shape_m - 1) / cluster_shape_m) * cluster_shape_m, 1) + + verification_required = False + if mode == "functional_L0" or 
mode == "functional_L1": + if '_void_' not in kernel_name: + verification_required = True + + m = max(int(grid_m * cta_tile_shape_m), alignment_ab_max) + n = max(int(grid_n * cta_tile_shape_n), alignment_ab_max) + k = int(k) + + # For functional testing, we want to perturb just a little from even shapes. + # Only do this if the perturbation does not cause one of the dimensions of the + # problem size to go to zero. This can occur for blockscaling kernels for which + # the alignment requirements for A and B can be quite large (e.g., 256). + if m > alignment_shift_m: + m -= alignment_shift_m + if n > alignment_shift_n: + n -= alignment_shift_n + + if '_n32t32_' in kernel_name: + continue + batch_count = 1 + if mode == "functional_L0" or mode == "functional_L1" : + if index_waves == 0 and index_k == 0 : + batch_count = 3 if mode == "functional_L0" else 5 + gemm_op = "gemm" + + grouped = is_grouped(manifest.operations_by_name[kernel_name].gemm_kind) + num_groups = 1 + if grouped: + gemm_op = "grouped_gemm" + num_groups = 3 # small to limit test time in host block-scaled reference kernels + batch_count = 1 + elif "bstensorop" in kernel_name: + gemm_op = "block_scaled_gemm" + elif is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind): + gemm_op = "blockwise_gemm" + + problem_size_category = ['smallK','largeK'][index_k] + '_' + ['beta==0','beta!=0'][bool(beta)] + + assert m > 0 and n > 0 and k > 0 + + # Emit per-testcase metadata for perf testing usage, eventually in perf database + metadata_dict = { + "input_params": { + 'problem_size_category' : problem_size_category, + 'operation' : _getSubOperationType(operation), + 'datatype' : data_types, + 'layout' : layout3x, + 'm' : m, + 'n' : n, + 'k' : k, + 'beta' : beta, + 'flops_per_byte' : _computeFlopsPerByte(operation, m, n, k, batch_count, beta, num_groups) + }, + "runtime_params": { + 'ctas_per_mma_instruction' : ctas_per_mma_instruction, + 'tilesize_m' : cta_tile_shape_m, + 'tilesize_n' : cta_tile_shape_n, + 
'tilesize_k' : cta_tile_shape_k, + 'cluster_shape_m' : cluster_shape_m, + 'cluster_shape_n' : cluster_shape_n, + } + } + + cluster_m_fallback = ctas_per_mma_instruction if dynamic_cluster else cluster_shape_m + cluster_n_fallback = 1 if dynamic_cluster else cluster_shape_n + cluster_k_fallback = 1 if dynamic_cluster else cluster_shape_k + + + if dynamic_datatype: + runtime_datatype_a, runtime_datatype_b = tuple(runtime_input_datatype) + metadata_dict["runtime_params"]["runtime_datatype_a"] = runtime_datatype_a + metadata_dict["runtime_params"]["runtime_datatype_b"] = runtime_datatype_b + + testcase_metadata = [ + f"cutlass_profiler --operation={gemm_op}" + + (f" --verification-providers=device --providers=cutlass" if profiler_flags_for_verification == "device" else " --mode=trace") + + f" --error-on-no-match --error-if-nothing-is-profiled" + + f" --kernels={kernel_name}" + + f" --m={str(m)}" + + f" --n={str(n)}" + + f" --k={str(k)}" + + (f" --num_groups={str(num_groups)}" if grouped else "") + + f" --cluster_m={str(cluster_shape_m)}" + + f" --cluster_n={str(cluster_shape_n)}" + + f" --cluster_k={str(cluster_shape_k)}" + + f" --cluster_m_fallback={str(cluster_m_fallback)}" + + f" --cluster_n_fallback={str(cluster_n_fallback)}" + + f" --cluster_k_fallback={str(cluster_k_fallback)}" + + f" --beta={str(beta)}" + + ("" if grouped else f" --batch_count={str(batch_count)}") + + f" --swizzle_size={str(swizzle_size)}" + + f" --verification-required={str(verification_required).lower()}" + ] \ + + output_dynamic_datatype = dynamic_datatype + if output_dynamic_datatype: + testcase_metadata[0] += (f" --runtime_input_datatype_a={runtime_datatype_a}" + + f" --runtime_input_datatype_b={runtime_datatype_b}") + + testcase_metadata.append(json.dumps(metadata_dict)) + testlist_csv_rows.append(testcase_metadata) + testcase_counter += 1 + + alpha = 1.0 + + if dynamic_datatype: + hashed_kernel_name = transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b) + + 
# If kernel_name is new, initialize its feature set with defaults + if hashed_kernel_name not in kernel_features: + kernel_features[hashed_kernel_name] = { + "is_support_dynamic_cluster": False, + "is_support_dynamic_datatype": False, + } + + # Update features for the hashed kernel name + kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] |= dynamic_cluster + kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] |= dynamic_datatype + + if hashed_kernel_name not in auditlist_csv_params_map: + auditlist_csv_params_map[hashed_kernel_name] = [] + + audit_row_params = get_kernel_params( + operation, + hashed_kernel_name, + (cluster_shape_m, cluster_shape_n, cluster_shape_k), + (cluster_m_fallback, cluster_n_fallback, cluster_k_fallback), + (m, n, k, batch_count), + alpha, beta, + dynamic_datatype, dynamic_cluster + ) + + auditlist_csv_params_map[hashed_kernel_name].append(audit_row_params) + + if hashed_kernel_name not in auditlist_csv_map: + audit_row = get_kernel_features(operation, hashed_kernel_name, dynamic_datatype, runtime_input_datatype) + auditlist_csv_map[hashed_kernel_name] = audit_row + + with open(outfile_name, 'w') as testlist_csv: + csv_writer = csv.writer(testlist_csv, delimiter=',') + csv_writer.writerow(testlist_csv_fields) + csv_writer.writerows(testlist_csv_rows) + + with open(audit_file_name, 'w') as auditlist_csv: + csv_writer = csv.writer(auditlist_csv, delimiter=',') + csv_writer.writerow(audit_csv_fields) + for hashed_kernel_name, row in auditlist_csv_map.items(): + # Append the dynamic features as "Y" or "N" + dynamic_cluster_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] else "N" + dynamic_datatype_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] else "N" + test_count = len(auditlist_csv_params_map[hashed_kernel_name]) + csv_writer.writerow(row + [dynamic_cluster_flag, dynamic_datatype_flag, test_count]) + + with open(audit_file_params_name, 'w') as 
auditlist_csv: + csv_writer = csv.writer(auditlist_csv, delimiter=',') + csv_writer.writerow(audit_csv_runtime_fields) + for kernel_index, (hashed_kernel_name, rows) in enumerate(auditlist_csv_params_map.items(), start=1): + for i, row in enumerate(rows): + if i == 0: + csv_writer.writerow([kernel_index, hashed_kernel_name] + row) + else: + csv_writer.writerow(["", ""] + row) + + print(f"Generated a total of {testcase_counter} test cases for {kernels_emitted} kernels out of {kernels_total} total.") + + # Generate a newline separated list of kernel filters + assert(len(kernel_name_set) == kernels_emitted) + output_filter_enabled = True + if output_filter_enabled: + kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list") + with open(kernel_filter_outfile_name, "w") as file: + kernel_name_set = set(map(lambda x: x.replace("_epi_tma", ""), kernel_name_set)) + for kernel_name in kernel_name_set: + file.write(kernel_name + "\n") + + # Sort L0 and L1 kernel list and csv file to avoid mixing cutlass3.x kernels and sm120_mma kernels in cutlass2.x generated together. 
+ if mode == "functional_L0" or mode == "functional_L1": + # Sort the .csv file + outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv") + with open(outfile_name) as file: + data = file.readlines() + data.sort() + with open(outfile_name, 'w') as file: + for i in range(len(data)): + file.write(data[i]) + # Sort the kernel list + kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list") + with open(kernel_filter_outfile_name) as file: + data = file.readlines() + data.sort() + with open(kernel_filter_outfile_name, 'w') as file: + for i in range(len(data)): + file.write(data[i]) + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/gemm_operation.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/gemm_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..0d2449e769303b738212cdcd896c9f2793ca2632 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/gemm_operation.py @@ -0,0 +1,1613 @@ + +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for emitting GEMM kernels +""" + +import collections +import enum +import functools +import logging +import operator +import os.path +import shutil + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + +_LOGGER = logging.getLogger(__name__) + +################################################################################################### +# +# Data structure modeling a GEMM operation +# +################################################################################################### + +# +class GemmOperation: + # + def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor 
= SwizzlingFunctor.Identity8, D = None, + kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto, + tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False, + ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None, + ScaleFactorMVecSize = None, ScaleFactorNVecSize = None, ScaleFactorKVecSize = None): + + kinds_3x = { + GemmKind.Universal3x, + GemmKind.SparseUniversal3x, + GemmKind.BlockScaledUniversal3x, + GemmKind.GroupedUniversal3x, + GemmKind.GroupedBlockScaledUniversal3x, + GemmKind.BlockwiseUniversal3x, + GemmKind.GroupedBlockwiseUniversal3x, + } + self.is_3x = gemm_kind in kinds_3x + self.prefix = "3x" if self.is_3x else "" + self.operation_kind = OperationKind.Gemm + self.arch = arch + self.tile_description = tile_description + self.gemm_kind = gemm_kind + self.A = A + self.B = B + self.C = C + self.D = D + + if is_block_scaled(gemm_kind): + self.ScaleFactorA = ScaleFactorA + self.ScaleFactorB = ScaleFactorB + self.ScaleFactorD = ScaleFactorD["tensor"] + self.ScaleFactorVectorSize = ScaleFactorD["vector_size"] + + if is_blockwise(gemm_kind): + self.ScaleFactorMVecSize = ScaleFactorMVecSize + self.ScaleFactorNVecSize = ScaleFactorNVecSize + self.ScaleFactorKVecSize = ScaleFactorKVecSize + + if self.D == None: + self.D = self.C + + if not self.is_3x: + assert(kernel_schedule == KernelScheduleType.ScheduleAuto) + assert(epilogue_schedule == EpilogueScheduleType.ScheduleAuto) + self.kernel_schedule = kernel_schedule + self.epilogue_schedule = epilogue_schedule + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + + if self.is_3x and epilogue_functor == EpilogueFunctor.LinearCombination: + self.epilogue_functor = EpilogueFunctor3x.LinearCombination + + self.swizzling_functor = swizzling_functor + self.tile_scheduler = tile_scheduler + + # Only enable mixed input mode and mixed input shuffle for Hopper + self.mixed_input_mode = None + 
if self.is_mixed_input() and self.arch >= 90 and self.arch < 100: + self.mixed_input_mode = mixed_input_mode + self.mixed_input_shuffle = (self.mixed_input_mode is not None) and mixed_input_shuffle + + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32 + ] + return self.tile_description.math_instruction.math_operation in complex_operators + + # + def is_mixed_input(self): + return self.A.element != self.B.element + + # + def is_planar_complex(self): + return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray) + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. 
''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and', + MathOperation.multiply_add_fast_accum: 'fastaccum', + } + + tensor_ops = [ + OpcodeClass.TensorOp, + OpcodeClass.WmmaTensorOp, + OpcodeClass.SparseTensorOp, + OpcodeClass.BlockScaledTensorOp, + ] + + is_tensor_op = self.tile_description.math_instruction.opcode_class in tensor_ops + + if is_tensor_op: + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' + + inst_shape = "{0}{1}{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape)) if not self.is_3x else "" + + inst_shape += math_op_string + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + short_math_name = self.short_math_name() if not self.is_3x else "" + + return "%s%s%s%s" % (short_math_name, inst_shape, intermediate_type, GemmKindNames[self.gemm_kind]) + + # Generates a string representing the MMA instruction. + def extended_name(self): + ''' Append data types if they differ from compute type. 
''' + element_sfa = "" + element_sfb = "" + if self.is_complex(): + extended_name = "${core_name}" + else: + if self.is_mixed_input(): + extended_name = "${core_name}_${element_a}_${element_b}" + if self.C.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_" + extended_name + elif is_blockwise(self.gemm_kind): + extended_name = "${core_name}_${element_sfa}x${element_a}_${element_sfb}x${element_b}" + element_sfa = DataTypeNames[self.accumulator_type()] + element_sfb = DataTypeNames[self.accumulator_type()] + else: + extended_name = "${core_name}" + if self.C.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_" + extended_name + if self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name += "_${element_a}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_sfa' : element_sfa, + 'element_b': DataTypeNames[self.B.element], + 'element_sfb' : element_sfb, + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def mixed_input_mode_name(self): + mode_name_mapping = { + MixedInputMode.ConvertOnly: "_cvt", + MixedInputMode.ScaleOnly: "_scl", + MixedInputMode.ScaleWithZeroPoint: "_sclzr" + } + mode_name = mode_name_mapping.get(self.mixed_input_mode, "") + if self.mixed_input_shuffle: + mode_name = mode_name + "_shfl" + return mode_name + + def extended_name_3x(self): + '''Generates a string representing the MMA atom. 
Assumes accumulator type is C type.''' + extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( + element_a = DataTypeNames[self.A.element], + element_b = DataTypeNames[self.B.element], + element_acc = DataTypeNames[self.accumulator_type()], + element_c = DataTypeNames[self.C.element], + element_d = DataTypeNames[self.D.element], + core_name = self.core_name()) + + if is_block_scaled(self.gemm_kind): + d_type_names = DataTypeNames[self.D.element] + + if self.ScaleFactorD.element != DataType.void: + d_type_names = DataTypeNames[self.ScaleFactorD.element] + "x" + d_type_names + + extended_name = "{core_name}_{element_sfa}x{element_a}_{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format( + element_sfa = DataTypeNames[self.ScaleFactorA], + element_a = DataTypeNames[self.A.element], + element_sfb = DataTypeNames[self.ScaleFactorB], + element_b = DataTypeNames[self.B.element], + element_acc = DataTypeNames[self.accumulator_type()], + element_c = DataTypeNames[self.C.element], + element_d = d_type_names, + core_name = self.core_name()) + + if is_blockwise(self.gemm_kind): + d_type_names = DataTypeNames[self.D.element] + + extended_name = "{core_name}_{sfvec_m_size}x{sfvec_k_size}{element_sfa}x{element_a}_{sfvec_n_size}x{sfvec_k_size}{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format( + element_sfa = DataTypeNames[self.accumulator_type()], + element_a = DataTypeNames[self.A.element], + element_sfb = DataTypeNames[self.accumulator_type()], + element_b = DataTypeNames[self.B.element], + element_acc = DataTypeNames[self.accumulator_type()], + element_c = DataTypeNames[self.C.element], + element_d = d_type_names, + sfvec_m_size = self.ScaleFactorMVecSize, + sfvec_n_size = self.ScaleFactorNVecSize, + sfvec_k_size = self.ScaleFactorKVecSize, + core_name = self.core_name()) + + if self.mixed_input_mode != None: + extended_name = extended_name + self.mixed_input_mode_name() + return extended_name 
+ + def datatype_name_3x(self): + '''Generates a string representing the MMA atom. Assumes accumulator type is C type.''' + datatype_name = "{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( + element_a = DataTypeNames[self.A.element], + element_b = DataTypeNames[self.B.element], + element_acc = DataTypeNames[self.accumulator_type()], + element_c = DataTypeNames[self.C.element], + element_d = DataTypeNames[self.D.element]) + return datatype_name + + # Generates a short string representing the AB layout tags (e.g. nt or tn) + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s%s" % ( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)] + ) + return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) + + # Generates a short string representing the ABC layout tags (e.g. ntn or tnn) + def layout_name_3x(self): + if self.is_complex() or self.is_planar_complex(): + return "{}{}{}".format( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)], + ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)]) + else: + return "{}{}{}".format( + ShortLayoutTypeNames[self.A.layout], + ShortLayoutTypeNames[self.B.layout], + ShortLayoutTypeNames[self.C.layout]) + + # Generates a short string representing underlying kernel schedule type + def kernel_schedule_name_3x(self): + return KernelScheduleSuffixes[self.kernel_schedule] + + # Generates a short string representing underlying epilogue schedule type + def epilogue_schedule_name_3x(self): + + if is_block_scaled(self.gemm_kind): + if self.ScaleFactorD.element != DataType.void: + return EpilogueScheduleSuffixes[self.epilogue_schedule] + "_epiVs" + str(self.ScaleFactorVectorSize)+ShortLayoutTypeNames[self.ScaleFactorD.layout] + + return 
EpilogueScheduleSuffixes[self.epilogue_schedule] + + # Generate a short string representing the operation class + def opcode_class_name(self): + return OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + def get_collective_tile_shape(self): + """ + Get the tile shape passed to the collective builder. + On Blackwell, this is different than the operation.tile_description.tile_shape. + """ + is_sm100_kernel = (self.arch == 100 or self.arch == 103) + if not is_sm100_kernel: + return self.tile_description.tile_shape + + opcode_class_main = self.tile_description.math_instruction.opcode_class + instruction_shape = self.tile_description.math_instruction.instruction_shape + tile_shape_m, tile_shape_n, tile_shape_k = self.tile_description.tile_shape + if opcode_class_main in [OpcodeClass.TensorOp, OpcodeClass.BlockScaledTensorOp, OpcodeClass.SparseTensorOp]: + tile_shape_m = instruction_shape[0] + tile_shape_n = instruction_shape[1] + return (tile_shape_m, tile_shape_n, tile_shape_k) + + # Generates the full kernel function name + def procedural_name(self): + return self._procedural_name + + @functools.cached_property + def _procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. 
''' + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + if self.arch >= 90: + kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}{ct}{cs}_{l}_{s}_align{al}{t}{k}{e}" + tile_shape = self.get_collective_tile_shape() + return kernel_name_template.format( + p = self.prefix, + ar = self.arch, + op = opcode_class_name, + ex = self.extended_name_3x(), + ct = '_' + 'x'.join([str(i) for i in tile_shape]) if tile_shape[0] > 0 else "", + cs = '_' + 'x'.join([str(i) for i in self.tile_description.cluster_shape]), + l = self.tile_description.stages, + s = self.layout_name_3x(), + al = str(max(self.A.alignment, self.B.alignment)), + t = TileSchedulerSuffixes[self.tile_scheduler], + k = self.kernel_schedule_name_3x(), + e = self.epilogue_schedule_name_3x()) + else: + threadblock = self.tile_description.procedural_name() + return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format( + p = self.prefix, + op = opcode_class_name, + ex = self.extended_name(), + tb = threadblock, + l = self.layout_name(), + a = str(max(self.A.alignment, self.B.alignment))) + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. 
''' + return self.procedural_name() + + def __hash__(self): + return hash(self.configuration_name()) + + def __eq__(self, other): + return self.configuration_name() == other.configuration_name() + +################################################################################################### +# +# Data structure modeling a grouped GEMM operation +# +################################################################################################### + +# +class GroupedGemmOperation(GemmOperation): + # + def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \ + scheduler_mode = GroupScheduleMode.Device): + super().__init__(gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor, swizzling_functor) + + self.scheduler_mode = scheduler_mode + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. 
''' + base = super().procedural_name() + return SubstituteTemplate( + base + "_schedule${schedule}", + { + 'schedule': ShortGroupScheduleModeNames[self.scheduler_mode] + }) + + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +# +class EmitGemmInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [] + self.gemm_template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = cutlass::gemm::device::Gemm< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + false, + ${math_operation} + ${residual} + >; +""" + self.gemm_complex_template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = cutlass::gemm::device::GemmComplex< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, 
${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${transform_a}, + ${transform_b}, + ${math_operation} + ${residual} + >; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + residual = '' + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), 
+ 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'residual': residual + } + + template = self.gemm_complex_template if operation.is_complex() else self.gemm_template + + return SubstituteTemplate(template, values) + +################################################################################################### + +class EmitSparseGemmInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [] + self.gemm_template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = cutlass::gemm::device::SparseGemm< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + false, + ${math_operation} + ${residual} + >; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}("${operation_name}")); 
+${compile_guard_end} +""" + + # + def emit(self, operation): + + warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + residual = '' + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': 
ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'residual': residual + } + + template = self.gemm_template + + return SubstituteTemplate(template, values) + +################################################################################################### + + +# +class EmitGemmUniversalInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/numeric_types.h", + "cutlass/arch/arch.h", + "cutlass/arch/mma.h", + "cutlass/layout/matrix.h", + "cutlass/gemm/device/gemm.h", + "cutlass/gemm/device/gemm_universal_adapter.h", + "cutlass/gemm/kernel/default_gemm_universal.h", + ] + self.builtin_epilogue_functor_template = """ + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + > +""" + self.gemm_template = """ +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmUniversal< + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operation} +>::GemmKernel; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + 
self.gemm_template_interleaved = """ +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operation} +>::GemmKernel; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + warp_count = operation.tile_description.warp_count + + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + transpose_layouts = { + LayoutType.ColumnMajor: LayoutType.RowMajor, + LayoutType.RowMajor: LayoutType.ColumnMajor + } + + if operation.A.layout in transpose_layouts.keys() and \ + operation.B.layout in transpose_layouts.keys() and \ + operation.C.layout in transpose_layouts.keys(): + + instance_layout_A = transpose_layouts[operation.A.layout] + instance_layout_B = transpose_layouts[operation.B.layout] + instance_layout_C = transpose_layouts[operation.C.layout] + + gemm_template = self.gemm_template + else: + instance_layout_A, instance_layout_B, instance_layout_C = \ + (operation.A.layout, operation.B.layout, operation.C.layout) + + gemm_template = 
self.gemm_template_interleaved + # + + # Support built-in epilogue functors or user-defined functions + if isinstance(operation.epilogue_functor, enum.Enum): + + epilogue_vector_length = \ + min(operation.C.alignment * DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element] + + values = { + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + } + epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values) + else: + epilogue_functor = self.epilogue_functor.emit_declaration() + # + + values = { + 'operation_name': operation.procedural_name(), + 'operation_suffix': self.operation_suffix, + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[instance_layout_A], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[instance_layout_B], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[instance_layout_C], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_functor': epilogue_functor, + 'swizzling_functor': 
SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation] + } + + return SubstituteTemplate(gemm_template, values) + + +################################################################################################### + +class EmitGemmUniversal3xInstance: + ''' Responsible for emitting a CUTLASS 3.x template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/gemm/gemm.h", + "cutlass/numeric_types.h", + "cutlass/gemm/kernel/gemm_universal.hpp", + "cutlass/gemm/collective/collective_builder.hpp", + "cutlass/epilogue/collective/collective_builder.hpp", + "cutlass/detail/blockwise_scale_layout.hpp", + ] + self.builtin_epilogue_functor_template = \ +"""${epilogue_functor}< + ${element_d}, + ${element_epilogue}, + ${element_c}, + ${element_epilogue} + >""" + + self.gemm_template = """ + +using ${operation_name}_epilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ${arch}, ${opcode_class_epi}, + cute::Shape, + cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>, + ${epi_tile_mn}, + ${element_accumulator}, ${element_epilogue}, + ${element_c}, ${layout_c}, ${align_c}, + ${element_d}, ${layout_d}, ${align_d}, + ${epilogue_schedule}, + ${epilogue_functor} + >::CollectiveOp; + +${mixed_dtype_prepare_code} +${blockwise_prepare_code} + +using ${operation_name}_mainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ${arch}, ${opcode_class_main}, + ${element_a}, ${layout_a}, ${align_a}, + ${element_b}, ${layout_b}, ${align_b}, + ${element_accumulator}, + 
cute::Shape, + cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>, + ${stages}, + ${kernel_schedule} + >::CollectiveOp; + +// Gemm operator ${operation_name} +using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< + ${problem_shape}, + ${operation_name}_mainloop, + ${operation_name}_epilogue, + ${tile_scheduler}>; + +// Define named type +struct ${operation_name} : + public ${operation_name}_base { }; + +""" + # + def instance_template(self): + return """ +${compile_guard_start} + { + using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>; + manifest.append( + new ${gemm_kind}("${operation_name}")); + } +${compile_guard_end} +""" + + + def emit_block_scale_epilogue_functor(self, operation): + block_scaled_template = """ + ${epilogue_functor}< + ${epi_vs}, + ${element_d}, + ${element_accumulator}, + ${element_sfd}, + ${layout_sfd}, + ${element_c}, + ${element_scalar} + > + """ + block_scaled_values = { + 'epi_vs' : str(operation.ScaleFactorVectorSize), + 'element_d': str(DataTypeTag[operation.D.element]), + 'element_sfd': str(DataTypeTag[operation.ScaleFactorD.element]), + 'layout_sfd': LayoutTag[operation.ScaleFactorD.layout], + 'epilogue_functor': EpilogueFunctor3xTag[EpilogueFunctor3x.LinearCombinationBlockScaleFactor], + 'element_accumulator': str(DataTypeTag[operation.accumulator_type()]), + 'element_scalar': str(DataTypeTag[operation.accumulator_type()]), + 'element_c': str(DataTypeTag[operation.C.element]), + } + return SubstituteTemplate(block_scaled_template, block_scaled_values) + + + @staticmethod + def pointerize_if_grouped(operation, layout): + return layout if not is_grouped(operation.gemm_kind) else layout + "* " + + @staticmethod + def transform_layout_A_if_blockwise(operation, layout): + layout_sfa = f"{operation.procedural_name()}_LayoutSFA" + layout_sfa = layout_sfa if not is_grouped(operation.gemm_kind) else layout_sfa + "* " + return layout if not is_blockwise(operation.gemm_kind) 
else f"cute::tuple<{layout}, {layout_sfa}>" + + @staticmethod + def transform_layout_B_if_blockwise(operation, layout): + layout_sfb = f"{operation.procedural_name()}_LayoutSFB" + layout_sfb = layout_sfb if not is_grouped(operation.gemm_kind) else layout_sfb + "* " + return layout if not is_blockwise(operation.gemm_kind) else f"cute::tuple<{layout}, {layout_sfb}>" + + @staticmethod + def problem_shape(operation): + gemm_shape_type = "cute::Shape" + grouped_gemm_shape_type = "cute::Shape" + grouped_gemm_shape_type = "cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">" + + return gemm_shape_type if not is_grouped(operation.gemm_kind) else grouped_gemm_shape_type + + def emit(self, operation): + _LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)") + _LOGGER.debug("*** operation.procedural_name(): " + operation.procedural_name()) + _LOGGER.debug("*** tile_shape: " + str(operation.tile_description.tile_shape)) + _LOGGER.debug("*** warp_count: " + str(operation.tile_description.warp_count)) + + opcode_class_main = operation.tile_description.math_instruction.opcode_class + opcode_class_epi = opcode_class_main + + tile_shape = operation.tile_description.tile_shape + instruction_shape = operation.tile_description.math_instruction.instruction_shape + cluster_m = operation.tile_description.cluster_shape[0] + cluster_n = operation.tile_description.cluster_shape[1] + cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1] + tile_shape_m, tile_shape_n, tile_shape_k = operation.get_collective_tile_shape() + + # stage count set to zero indicates builder automatic stage selection + if operation.tile_description.stages > 0: + stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>" + elif opcode_class_main == OpcodeClass.SparseTensorOp and operation.arch == 100: + stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveoutEpi<{str(operation.procedural_name())}_epilogue>" + else: + 
stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage))>" + + epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto" + + instance_layout_A, instance_layout_B, instance_layout_C , instance_layout_D = \ + (operation.A.layout, operation.B.layout, operation.C.layout, operation.D.layout) + + # 3.0 profiler integration only supports trivial epilogues for now + epilogue_vector_length = 1 + + # Support built-in epilogue functors or user-defined functions + if isinstance(operation.epilogue_functor, enum.Enum): + values = { + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctor3xTag[operation.epilogue_functor], + } + epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values) + + if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void: + epilogue_functor = self.emit_block_scale_epilogue_functor(operation) + + + else: + epilogue_functor = self.epilogue_functor.emit_declaration() + + if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void: + epilogue_functor = self.emit_block_scale_epilogue_functor(operation) + + # + # Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, e.g. cute::tuple, Transform : cute::identity / cute::conjugate. 
+ element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>" + element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>" + epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] + + if opcode_class_main == OpcodeClass.BlockScaledTensorOp: + grouped = is_grouped(operation.gemm_kind) + if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped): + epi_tile_mn = "cute::Shape" + if is_tma_epilogue(operation.epilogue_schedule): + epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)] + if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped): + epi_tile_mn = "cute::Shape" + if is_tma_epilogue(operation.epilogue_schedule): + epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)] + # SM103 FP4 Ultra + is_sm103_fp4_ultra_1sm_kernel_schedule = operation.kernel_schedule in [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, grouped) + ] + is_sm103_fp4_ultra_2sm_kernel_schedule = 
operation.kernel_schedule in [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, grouped) + ] + if cta_n == 256 and is_sm103_fp4_ultra_1sm_kernel_schedule: + epi_tile_mn = "cute::Shape" + if is_tma_epilogue(operation.epilogue_schedule): + epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)] + if cta_n == 256 and is_sm103_fp4_ultra_2sm_kernel_schedule: + epi_tile_mn = "cute::Shape" + if is_tma_epilogue(operation.epilogue_schedule): + epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)] + + element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>' + element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>' + + alignment_c = get_tma_alignment(operation.C.element) \ + if is_tma_epilogue(operation.epilogue_schedule) and opcode_class_epi != OpcodeClass.Simt \ + else operation.C.alignment + alignment_d = get_tma_alignment(operation.D.element) \ + if is_tma_epilogue(operation.epilogue_schedule) and opcode_class_epi != OpcodeClass.Simt \ + else operation.D.alignment + + operation_name_str = operation.procedural_name() + layout_a_str = LayoutTag[instance_layout_A] + layout_b_str = LayoutTag[instance_layout_B] + mixed_dtype_prepare_code = "" + if operation.mixed_input_mode != None: + A_dtype = operation.A.element + B_dtype = operation.B.element 
+ A_dtype_bits = DataTypeSize[A_dtype] + B_dtype_bits = DataTypeSize[B_dtype] + is_A_dtype_narrow = A_dtype_bits < B_dtype_bits + if is_A_dtype_narrow: + narrow_dtype, wide_dtype = (A_dtype, B_dtype) + narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits) + else: + narrow_dtype, wide_dtype = (B_dtype, A_dtype) + narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits) + + narrow_tag = DataTypeTag[narrow_dtype] + wide_tag = DataTypeTag[wide_dtype] + scale_tag = DataTypeTag[wide_dtype] + zero_tag = DataTypeTag[wide_dtype] + + do_shuffle = False + value_shuffle_str = "" + if narrow_dtype_bits == 4 and wide_dtype_bits == 16: + value_shuffle_str = "cute::Layout, cute::Stride>" + do_shuffle = True + if narrow_dtype_bits == 8 and wide_dtype_bits == 16: + value_shuffle_str = "cute::Layout, cute::Stride>" + do_shuffle = True + do_shuffle = operation.mixed_input_shuffle and do_shuffle + + if do_shuffle: + if is_A_dtype_narrow: + stride_narrow_str = f"cutlass::detail::TagToStrideA_t<{layout_a_str}>" + layout_a_str = f"{operation_name_str}_LayoutNarrowReordered" + else: + stride_narrow_str = f"cutlass::detail::TagToStrideB_t<{layout_b_str}>" + layout_b_str = f"{operation_name_str}_LayoutNarrowReordered" + # The {operation_name_str}_ prefixs in mixed_dtype_prepare_code and + # layout_{a, b}_str are to prevent errors in Windows platform unity build + mixed_dtype_prepare_code = f""" +using {operation_name_str}_StrideNarrow = {stride_narrow_str}; +using {operation_name_str}_ValueShuffle = {value_shuffle_str}; +static constexpr int {operation_name_str}_NumShuffleAtoms = 1; +using {operation_name_str}_MmaAtomShape = cute::Layout>>; +using {operation_name_str}_LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<{wide_tag}, {operation_name_str}_MmaAtomShape, {operation_name_str}_ValueShuffle>()); +using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape({operation_name_str}_LayoutAtomQuant{{}}, cute::Layout, 
{operation_name_str}_StrideNarrow>{{}})); + """ + + mixed_input_modes_to_element = { + MixedInputMode.ConvertOnly: narrow_tag, + MixedInputMode.ScaleOnly: f"cute::tuple<{narrow_tag}, {scale_tag}>", + MixedInputMode.ScaleWithZeroPoint: f"cute::tuple<{narrow_tag}, {scale_tag}, {zero_tag}>" + } + narrow_element = mixed_input_modes_to_element.get(operation.mixed_input_mode, narrow_tag) + + if narrow_dtype == DataType.s4 and (wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2): + narrow_element = f"cute::tuple<{narrow_tag}, cutlass::Array<{scale_tag}, 8>>" + + if is_A_dtype_narrow: + element_a = narrow_element + else: + element_b = narrow_element + + blockwise_prepare_code = "" + if is_blockwise(operation.gemm_kind): + sfm_vec_size = operation.ScaleFactorMVecSize + sfn_vec_size = operation.ScaleFactorNVecSize + sfk_vec_size = operation.ScaleFactorKVecSize + blockwise_prepare_code = f""" +using {operation_name_str}_ScaleConfig = cutlass::detail::Sm{operation.arch}BlockwiseScaleConfig<{sfm_vec_size}, {sfn_vec_size}, {sfk_vec_size}>; +using {operation_name_str}_LayoutSFA = decltype({operation_name_str}_ScaleConfig::deduce_layoutSFA()); +using {operation_name_str}_LayoutSFB = decltype({operation_name_str}_ScaleConfig::deduce_layoutSFB()); + """ + + values = { + 'operation_name': operation_name_str, + 'operation_suffix': self.operation_suffix, + 'problem_shape': self.problem_shape(operation), + 'element_a': element_a, + 'layout_a': self.transform_layout_A_if_blockwise(operation, self.pointerize_if_grouped(operation, layout_a_str)), + 'element_b': element_b, + 'layout_b': self.transform_layout_B_if_blockwise(operation, self.pointerize_if_grouped(operation, layout_b_str)), + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_C]), + 'element_d': DataTypeTag[operation.D.element], + 'layout_d': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_D]), + 'element_accumulator': 
DataTypeTag[operation.accumulator_type()], + 'opcode_class_main': OpcodeClassTag[opcode_class_main], + 'opcode_class_epi': OpcodeClassTag[opcode_class_epi], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'tile_shape_m': str(tile_shape_m), + 'tile_shape_n': str(tile_shape_n), + 'tile_shape_k': str(tile_shape_k), + 'cluster_shape_m': 'cute::_' + str(operation.tile_description.cluster_shape[0]) if operation.tile_description.cluster_shape[0] > 0 else "int", + 'cluster_shape_n': 'cute::_' + str(operation.tile_description.cluster_shape[1]) if operation.tile_description.cluster_shape[1] > 0 else "int", + 'cluster_shape_k': 'cute::_' + str(operation.tile_description.cluster_shape[2]) if operation.tile_description.cluster_shape[2] > 0 else "int", + 'instruction_shape_m': str(instruction_shape[0]), + 'instruction_shape_n': str(instruction_shape[1]), + 'instruction_shape_k': str(instruction_shape[2]), + 'kernel_schedule' : str(KernelScheduleTag[operation.kernel_schedule]), + 'epilogue_schedule' : str(epilogue_schedule_type), + 'epi_tile_mn' : epi_tile_mn, + 'epilogue_functor': epilogue_functor, + 'stages': stage_count_string, + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'align_c': str(alignment_c), + 'align_d': str(alignment_d), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'tile_scheduler': str(TileSchedulerTag[operation.tile_scheduler]), + 'mixed_dtype_prepare_code': mixed_dtype_prepare_code, + 'blockwise_prepare_code' : blockwise_prepare_code + } + + return SubstituteTemplate(self.gemm_template, values) + +################################################################################################### + +# +class 
EmitGemmPlanarComplexInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [] + self.template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b}, + ${element_c}, cutlass::layout::RowMajor, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ${element_c}, + ${alignment_c}, + ${element_accumulator}, + ${element_epilogue} + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator} + >::GemmKernel; + + struct ${operation_name} : + public Operation_${operation_name} { }; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] + + # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major + transposed_layout_A = TransposedLayout[operation.A.layout] + transposed_layout_B = TransposedLayout[operation.B.layout] + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.B.element], + 'layout_a': LayoutTag[transposed_layout_B], + 
'transform_a': ComplexTransformTag[operation.B.complex_transform], + 'alignment_a': str(operation.B.alignment), + 'element_b': DataTypeTag[operation.A.element], + 'layout_b': LayoutTag[transposed_layout_A], + 'transform_b': ComplexTransformTag[operation.A.complex_transform], + 'alignment_b': str(operation.A.alignment), + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'alignment_c': str(operation.C.alignment), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'stages': str(operation.tile_description.stages), + 'math_operator': 'cutlass::arch::OpMultiplyAdd' + } + + return SubstituteTemplate(self.template, values) + +################################################################################################### + +# +class EmitGemmPlanarComplexArrayInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [] + self.template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = 
typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b}, + ${element_c}, cutlass::layout::RowMajor, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ${element_c}, + ${alignment_c}, + ${element_accumulator}, + ${element_epilogue} + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator} + >::GemmArrayKernel; + + struct ${operation_name} : public Operation_${operation_name} { }; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] + + # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major + transposed_layout_A = TransposedLayout[operation.A.layout] + transposed_layout_B = TransposedLayout[operation.B.layout] + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.B.element], + 'layout_a': LayoutTag[transposed_layout_B], + 'transform_a': ComplexTransformTag[operation.B.complex_transform], + 'alignment_a': str(operation.B.alignment), + 'element_b': DataTypeTag[operation.A.element], + 'layout_b': LayoutTag[transposed_layout_A], + 'transform_b': ComplexTransformTag[operation.A.complex_transform], + 'alignment_b': 
str(operation.A.alignment), + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'alignment_c': str(operation.C.alignment), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'stages': str(operation.tile_description.stages), + 'math_operator': 'cutlass::arch::OpMultiplyAdd' + } + + return SubstituteTemplate(self.template, values) + +################################################################################################### + +# +class EmitGemmGroupedInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/numeric_types.h", + "cutlass/arch/arch.h", + "cutlass/arch/mma.h", + "cutlass/layout/matrix.h", + "cutlass/gemm/device/gemm.h", + "cutlass/gemm/kernel/gemm_grouped.h", + "cutlass/gemm/kernel/default_gemm_grouped.h", + "cutlass/gemm/device/gemm_grouped.h" + ] + self.builtin_epilogue_functor_template = \ +"""${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, 
def emit(self, operation):
  """Render the C++ type declaration for one grouped GEMM operation.

  Args:
    operation: operation description. This method reads its
      ``tile_description``, the ``A``/``B``/``C`` operand descriptions,
      ``epilogue_functor``, ``element_epilogue``, ``swizzling_functor``,
      ``scheduler_mode``, ``arch``, ``accumulator_type()`` and
      ``procedural_name()``.

  Returns:
    ``self.gemm_template`` with every ``${...}`` placeholder substituted.
  """
  tile = operation.tile_description
  math_instruction = tile.math_instruction
  threadblock_shape = tile.threadblock_shape
  instruction_shape = math_instruction.instruction_shape
  warp_count = tile.warp_count

  # Per-warp tile: the threadblock extent divided by the warp count along
  # each of the three dimensions.
  warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]

  # Support built-in epilogue functors or user-defined functions
  if isinstance(operation.epilogue_functor, enum.Enum):
    element_c_size = DataTypeSize[operation.C.element]

    # Clamp (alignment * element size) to 128, then convert back to a count
    # of C elements.
    epilogue_vector_length = \
      min(operation.C.alignment * element_c_size, 128) // element_c_size

    epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, {
      'epilogue_vector_length': str(epilogue_vector_length),
      'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
      'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
    })
  else:
    # NOTE(review): this reads self.epilogue_functor rather than
    # operation.epilogue_functor - confirm the emitter defines that attribute.
    epilogue_functor = self.epilogue_functor.emit_declaration()

  values = {
    'operation_name': operation.procedural_name(),
    'operation_suffix': self.operation_suffix,
    'element_a': DataTypeTag[operation.A.element],
    'layout_a': LayoutTag[operation.A.layout],
    'element_b': DataTypeTag[operation.B.element],
    'layout_b': LayoutTag[operation.B.layout],
    'element_c': DataTypeTag[operation.C.element],
    'layout_c': LayoutTag[operation.C.layout],
    'element_accumulator': DataTypeTag[operation.accumulator_type()],
    'opcode_class': OpcodeClassTag[math_instruction.opcode_class],
    'arch': "cutlass::arch::Sm%d" % operation.arch,
    'threadblock_shape_m': str(threadblock_shape[0]),
    'threadblock_shape_n': str(threadblock_shape[1]),
    'threadblock_shape_k': str(threadblock_shape[2]),
    'warp_shape_m': str(warp_shape[0]),
    'warp_shape_n': str(warp_shape[1]),
    'warp_shape_k': str(warp_shape[2]),
    'instruction_shape_m': str(instruction_shape[0]),
    'instruction_shape_n': str(instruction_shape[1]),
    'instruction_shape_k': str(instruction_shape[2]),
    'epilogue_functor': epilogue_functor,
    'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
    'stages': str(tile.stages),
    'align_a': str(operation.A.alignment),
    'align_b': str(operation.B.alignment),
    'transform_a': ComplexTransformTag[operation.A.complex_transform],
    'transform_b': ComplexTransformTag[operation.B.complex_transform],
    'scheduler_mode': GroupScheduleModeTag[operation.scheduler_mode],
    'math_operation': MathOperationTag[math_instruction.math_operation]
  }

  return SubstituteTemplate(self.gemm_template, values)
GemmKind.GroupedBlockScaledUniversal3x: 'GroupedBlockScaledGemmUniversal3xOperation', + GemmKind.BlockwiseUniversal3x: 'BlockwiseGemmUniversal3xOperation', + GemmKind.GroupedBlockwiseUniversal3x: 'GroupedBlockwiseGemmUniversal3xOperation', + } + + self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)" + + self.separator = """ +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.header_template = """ +/* + Generated by gemm_operation.py - Do not edit. +*/ +""" + + self.initialize_function_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_${configuration_name}(Manifest &manifest) { + +""" + self.epilogue_template = """ + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def __enter__(self): + _LOGGER.debug("*** EmitGemmConfigurationLibrary::__enter__") + _LOGGER.debug("*** configuration_path (file to write): " + + str(self.configuration_path)) + + self.configuration_file = open(self.configuration_path, "w") + self.configuration_file.write(self.header_template) + self.configuration_file.write(self.separator) + + self.includes = collections.OrderedDict([ + ("cutlass/cutlass.h", None), + ("cutlass/library/library.h", None), + ("cutlass/library/manifest.h", None), + ("library_internal.h", None), + ("gemm_operation.h", None), + ("gemm_operation_3x.hpp", None), + ("grouped_gemm_operation_3x.hpp", None), + ("sparse_gemm_operation_3x.hpp", None), + ("block_scaled_gemm_operation_3x.hpp", None), + ("blockwise_gemm_operation_3x.hpp", None), + 
def __exit__(self, exception_type, exception_value, traceback):
  """Write the collected configuration to the output file and close it.

  The output order is: one ``#include`` line per key of ``self.includes``,
  the separator, every instance definition, the ``initialize_*`` function
  header, every instance wrapper, and the closing epilogue.

  The exception arguments are not inspected, and nothing is returned, so an
  exception raised inside the ``with`` body keeps propagating.
  """
  output = self.configuration_file
  try:
    # Write includes
    output.writelines("#include \"%s\"\n" % incl for incl in self.includes)

    output.write(self.separator)

    # Write instance definitions in top-level namespace
    output.writelines(self.instance_definitions)

    # Add wrapper objects within initialize() function
    output.write(SubstituteTemplate(self.initialize_function_template, {
      'configuration_name': self.configuration_name
    }))

    output.writelines(self.instance_wrappers)

    output.write(self.epilogue_template)
  finally:
    # Release the handle even when a write or substitution raises.
    output.close()
+################################################################################################### +################################################################################################### diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/generator.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..063e8fb1caa6626e8ba099133fee4dd3dc115e40 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/generator.py @@ -0,0 +1,10962 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
def log_debug_line(line: str, indent_level: int = 0) -> None:
  """Send ``line`` to the module logger at DEBUG level.

  The text is preceded by whatever ``logging_prefix(indent_level)`` returns.
  """
  message = logging_prefix(indent_level) + line
  _LOGGER.debug(message)
def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0):
  """Return True if the given CUDA Toolkit version is at least major.minor.patch.

  Args:
    semantic_ver_string: dot-separated version such as ``'12.4'`` or
      ``'12.4.131'``. Only the first three components are read. ``''``,
      ``None`` or a blank string selects the built-in fallback version.
    major, minor, patch: minimum version required.

  Returns:
    Result of comparing ``[major', minor', patch']`` with
    ``[major, minor, patch]`` as lists.

  Raises:
    ValueError: if one of the first three components is not an integer.
  """
  # Fallback version. Components that the string does not supply keep these
  # values, so '12' is read as 12.0.132.
  cuda_version = [11, 0, 132]

  version_text = (semantic_ver_string or '').strip()
  if version_text:
    # The slice yields at most three components, so every index is valid.
    for index, component in enumerate(version_text.split('.')[:3]):
      cuda_version[index] = int(component)

  return cuda_version >= [major, minor, patch]
def EpilogueAlignment(max_alignment, tile, epilogue_steps = 8):
  """Compute the maximum alignment of the epilogue.

  Multiplies every threadblock extent except the last, then floor-divides
  that by the total warp count, by 32, and by ``epilogue_steps``. The result
  is capped at ``max_alignment``.
  """
  tile_elements = 1
  for extent in tile.threadblock_shape[:-1]:
    tile_elements *= extent

  total_warps = 1
  for count in tile.warp_count:
    total_warps *= count

  # 32 is presumably the number of threads per warp - TODO confirm.
  elements_per_warp = tile_elements // total_warps
  elements_per_thread = elements_per_warp // 32
  elements_per_step = elements_per_thread // epilogue_steps

  return min(max_alignment, elements_per_step)
# Generates 3.0 API based GemmUniversal API kernels. Alignment constraints are folded in with layouts
def CreateGemmUniversal3xOperator(
    manifest, layouts, tile_descriptions, data_types,
    schedules = [[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto]],
    complex_transforms=None,
    epilogue_functor=EpilogueFunctor.LinearCombination,
    swizzling_functor=SwizzlingFunctor.Identity1,
    tile_schedulers=[TileSchedulerType.Default],
    gemm_kind=GemmKind.Universal3x):
  """Create GemmOperation objects for the 3.x universal GEMM API.

  One operation is built for every combination of layout, tile description,
  data type, complex transform, (kernel, epilogue) schedule pair and tile
  scheduler, and then for every mixed-input mode / shuffle option derived
  from the A and B element sizes. Each operation is appended to ``manifest``.

  Args:
    manifest: receives each operation through ``append``; when its
      ``kernel_filter`` is ``''`` only ``tile_descriptions[0]`` is used.
    layouts: triples of ``(layout, alignment)`` pairs for A, B and C/D.
    tile_descriptions: candidate tile descriptions.
    data_types: a dict, or a sequence of dicts, with the keys ``a_type``,
      ``b_type``, ``c_type``, ``d_type``, ``acc_type`` and optionally
      ``epi_type``, ``sf_type`` and ``sfd_type``.
    schedules: two-element ``[kernel_schedule, epilogue_schedule]`` entries.
      The default list is only read, never modified.
    complex_transforms: ``(A, B)`` transform pairs; defaults to no transform.
    epilogue_functor, swizzling_functor, gemm_kind: forwarded to GemmOperation.
    tile_schedulers: tile scheduler types to enumerate (read only).

  Returns:
    The list of operations that were created.
  """
  if isinstance(data_types, dict):
    data_types = [data_types]

  for schedule_pair in schedules:
    assert len(schedule_pair) == 2

  if complex_transforms is None:
    complex_transforms = [(ComplexTransform.none, ComplexTransform.none), ]

  operations = []

  # by default, only generate the largest tile and largest alignment
  if manifest.kernel_filter == '':
    if len(tile_descriptions) == 0:
      return operations
    tile_descriptions = [tile_descriptions[0]]

  combinations = product(layouts, tile_descriptions, data_types, complex_transforms, schedules, tile_schedulers)
  # `schedule` is deliberately a different name from the `schedules` parameter.
  for layout, tile_description, data_type, complex_transform, schedule, tile_scheduler in combinations:
    kernel_schedule, epilogue_schedule = schedule

    A = TensorDescription(
        data_type["a_type"], layout[0][0], layout[0][1], complex_transform[0])
    B = TensorDescription(
        data_type["b_type"], layout[1][0], layout[1][1], complex_transform[1])

    # C and D are described with the same layout entry.
    C = TensorDescription(data_type["c_type"], layout[2][0], layout[2][1])
    D = TensorDescription(data_type["d_type"], layout[2][0], layout[2][1])

    gemm_op_extra_args = {}
    element_compute = data_type.get("epi_type", data_type["acc_type"])

    if "sf_type" in data_type:
      gemm_op_extra_args["ScaleFactorA"] = data_type["sf_type"]
      gemm_op_extra_args["ScaleFactorB"] = data_type["sf_type"]
      gemm_op_extra_args["ScaleFactorD"] = {
        "tensor": TensorDescription(data_type["sfd_type"]["type"], data_type["sfd_type"]["layout"]),
        "vector_size" : data_type["sfd_type"]["vector_size"]}
      assert is_block_scaled(gemm_kind)

    if tile_description.explicit_vector_sizes is not None:
      assert len(tile_description.explicit_vector_sizes) == 3
      gemm_op_extra_args["ScaleFactorMVecSize"] = tile_description.explicit_vector_sizes[0]
      gemm_op_extra_args["ScaleFactorNVecSize"] = tile_description.explicit_vector_sizes[1]
      gemm_op_extra_args["ScaleFactorKVecSize"] = tile_description.explicit_vector_sizes[2]
      assert is_blockwise(gemm_kind)
    else:
      assert not is_blockwise(gemm_kind)

    # Identify the narrower of the two input element types. On a tie B is
    # treated as the narrow one.
    A_dtype = data_type["a_type"]
    B_dtype = data_type["b_type"]
    A_dtype_bits = DataTypeSize[A_dtype]
    B_dtype_bits = DataTypeSize[B_dtype]
    if A_dtype_bits < B_dtype_bits:
      narrow_dtype, wide_dtype = (A_dtype, B_dtype)
      narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits)
    else:
      narrow_dtype, wide_dtype = (B_dtype, A_dtype)
      narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits)

    # Inputs of equal size get the single mode None.
    mixed_input_modes = [None]
    if narrow_dtype_bits != wide_dtype_bits:
      if narrow_dtype == DataType.s4 and (wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2):
        mixed_input_modes = [MixedInputMode.ScaleOnly]
      else:
        mixed_input_modes = [MixedInputMode.ConvertOnly, MixedInputMode.ScaleOnly, MixedInputMode.ScaleWithZeroPoint]

    # The shuffled variant is only enumerated for 4- or 8-bit narrow inputs
    # paired with a 16-bit wide input.
    mixed_input_shuffle_options = [False]
    if (mixed_input_modes[0] is not None) and (wide_dtype_bits == 16) and (narrow_dtype_bits == 4 or narrow_dtype_bits == 8):
      mixed_input_shuffle_options = [False, True]

    for mixed_input_mode, mixed_input_shuffle in product(mixed_input_modes, mixed_input_shuffle_options):
      operation = GemmOperation(
        gemm_kind, tile_description.minimum_compute_capability,
        tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D,
        kernel_schedule, epilogue_schedule, tile_scheduler,
        mixed_input_mode=mixed_input_mode, mixed_input_shuffle=mixed_input_shuffle, **gemm_op_extra_args)
      manifest.append(operation)
      operations.append(operation)

  return operations
#
def CreateSparseGemmOperator(manifest, layouts, tile_descriptions, data_type, \
  alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \
  swizzling_functor = SwizzlingFunctor.Identity8):
  """Create ``GemmKind.Sparse`` operations and append them to ``manifest``.

  One operation is built per (layout, tile description, alignment, complex
  transform) combination. When ``manifest.kernel_filter`` is ``''`` only the
  first tile description and the first alignment are used.

  Args:
    manifest: receives each operation through ``append``.
    layouts: ``(layout_a, layout_b, layout_c)`` triples.
    tile_descriptions: candidate tile descriptions.
    data_type: ``(element_a, element_b, element_c, element_epilogue)``.
    alignment_constraints: alignments applied to A and B; C uses
      ``min(8, alignment)``.
    complex_transforms: ``(A, B)`` transform pairs; defaults to no transform.
    epilogue_functor, swizzling_functor: forwarded to GemmOperation.

  Returns:
    The list of operations that were created. This is empty when there is
    nothing to enumerate.
  """
  if complex_transforms is None:
    complex_transforms = [(ComplexTransform.none, ComplexTransform.none),]

  element_a, element_b, element_c, element_epilogue = data_type

  operations = []

  # by default, only generate the largest tile and largest alignment
  if manifest.kernel_filter == '':
    # An empty candidate list yields no operations, as it does on the
    # unfiltered path below, rather than an IndexError.
    if len(tile_descriptions) == 0 or len(alignment_constraints) == 0:
      return operations
    tile_descriptions = [tile_descriptions[0],]
    alignment_constraints = [alignment_constraints[0],]

  for layout in layouts:
    for tile_description in tile_descriptions:
      for alignment in alignment_constraints:
        for complex_transform in complex_transforms:

          alignment_c = min(8, alignment)

          A = TensorDescription(element_a, layout[0], alignment, complex_transform[0])
          B = TensorDescription(element_b, layout[1], alignment, complex_transform[1])
          C = TensorDescription(element_c, layout[2], alignment_c)

          new_operation = GemmOperation(GemmKind.Sparse, tile_description.minimum_compute_capability, \
            tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor)

          manifest.append(new_operation)
          operations.append(new_operation)

  return operations
#
def CreateGemmGroupedOperator(manifest, layouts, tile_descriptions, data_type, \
  alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \
  swizzling_functor = SwizzlingFunctor.Identity8):
  """Create ``GemmKind.Grouped`` operations and append them to ``manifest``.

  One GroupedGemmOperation is built per (layout, tile description,
  alignment, complex transform) combination. When ``manifest.kernel_filter``
  is ``''`` only the first tile description and the first alignment are used.

  Args:
    manifest: receives each operation through ``append``.
    layouts: ``(layout_a, layout_b, layout_c)`` triples.
    tile_descriptions: candidate tile descriptions.
    data_type: ``(element_a, element_b, element_c, element_epilogue)``.
    alignment_constraints: alignments applied to A and B; C uses
      ``min(8, alignment)``.
    complex_transforms: ``(A, B)`` transform pairs; defaults to no transform.
    epilogue_functor, swizzling_functor: forwarded to GroupedGemmOperation.

  Returns:
    The list of operations that were created. This is empty when there is
    nothing to enumerate.
  """
  if complex_transforms is None:
    complex_transforms = [(ComplexTransform.none, ComplexTransform.none),]

  element_a, element_b, element_c, element_epilogue = data_type

  operations = []

  # by default, only generate the largest tile and largest alignment
  if manifest.kernel_filter == '':
    # An empty candidate list yields no operations, as it does on the
    # unfiltered path below, rather than an IndexError.
    if len(tile_descriptions) == 0 or len(alignment_constraints) == 0:
      return operations
    tile_descriptions = [tile_descriptions[0],]
    alignment_constraints = [alignment_constraints[0],]

  for layout in layouts:
    for tile_description in tile_descriptions:
      for alignment in alignment_constraints:
        for complex_transform in complex_transforms:

          alignment_c = min(8, alignment)

          A = TensorDescription(element_a, layout[0], alignment, complex_transform[0])
          B = TensorDescription(element_b, layout[1], alignment, complex_transform[1])
          C = TensorDescription(element_c, layout[2], alignment_c)

          new_operation = GroupedGemmOperation(GemmKind.Grouped, tile_description.minimum_compute_capability, \
            tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor)

          manifest.append(new_operation)
          operations.append(new_operation)

  return operations
#
def CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, data_type, \
  alignment_constraints, blas_mode, epilogue_functor = EpilogueFunctor.LinearCombination, \
  swizzling_functor = SwizzlingFunctor.Identity8):
  """Create SYMM/HEMM operations and append them to ``manifest``.

  For every (layout, side mode, fill mode, tile description, alignment)
  combination, two SymmOperation objects with identical arguments are built
  and registered. When ``manifest.kernel_filter`` is ``''`` only the first
  tile description and the first alignment are used.

  Returns:
    The list of operations that were created.
  """
  element_a, element_b, element_c, element_epilogue = data_type

  created = []

  # by default, only generate the largest tile and largest alignment
  if manifest.kernel_filter == '':
    tile_descriptions = [tile_descriptions[0],]
    alignment_constraints = [alignment_constraints[0],]

  # SYMM supported layouts (RowMajor, ColumnMajor) with no conjugation
  no_transform = ComplexTransform.none

  # No vectorized access for the triangular matrix
  symmetric_alignment = 1

  for layout in layouts:
    for side_mode in side_modes:
      for fill_mode in fill_modes:
        for tile in tile_descriptions:
          for alignment in alignment_constraints:

            A = SymmetricTensorDescription(
              element_a, layout[0], fill_mode, symmetric_alignment, no_transform, side_mode)
            # tensor A and B have same data type and layout
            B = TensorDescription(element_b, layout[0], alignment)
            C = TensorDescription(element_c, layout[1], min(8, alignment))

            # NOTE(review): the same SYMM/HEMM operation is built and
            # registered twice per combination. This looks like a leftover
            # from a Rank-K / Rank-2K style pair - confirm that the duplicate
            # is intended.
            for _ in range(2):
              symm_operation = SymmOperation(
                SymmKind.Universal, tile.minimum_compute_capability,
                tile, A, B, C, element_epilogue, epilogue_functor, swizzling_functor, blas_mode)

              manifest.append(symm_operation)
              created.append(symm_operation)

  return created
# Convolution for 2D operations
def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints, \
  conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \
  epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4):
  """Create 2D convolution operations and append them to ``manifest``.

  For every (tile, alignment) pair this builds Conv2dOperation objects for
  each kind listed in ``conv_kinds``:

  * Fprop: one operation per iterator algorithm, plus grouped variants when
    the tile uses TensorOp, A is TensorNHWC and the tile's minimum compute
    capability is at least 80.
  * Dgrad: one unity-stride operation per iterator algorithm, plus one
    strided operation for each of Analytic / Optimized that is enabled.
  * Wgrad: one strided operation per iterator algorithm.

  When ``manifest.kernel_filter`` is ``''`` only the first tile, the first
  alignment and the Optimized iterator are used.

  Args:
    manifest: receives each operation through ``append``.
    layout: indexable; entries 0, 1 and 2 are the layouts of A, B and C.
    tile_descriptions: candidate tile descriptions.
    data_type: ``(element_a, element_b, element_c, element_epilogue)``.
    alignment_constraints: alignments for A and B; C uses ``min(8, alignment)``.
    conv_kinds: convolution kinds to generate (the default list is only read).
    epilogue_functor, swizzling_functor: forwarded to Conv2dOperation.

  Returns:
    The list of operations that were created.
  """

  element_a, element_b, element_c, element_epilogue = data_type

  # iterator algorithm (analytic and optimized)
  iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized]

  # by default, only generate the largest tile size, largest alignment, and optimized iterator
  if manifest.kernel_filter == '':
    tile_descriptions = [tile_descriptions[0],]
    alignment_constraints = [alignment_constraints[0],]
    iterator_algorithms = [IteratorAlgorithm.Optimized]

  operations = []

  for tile in tile_descriptions:
    for alignment in alignment_constraints:

      alignment_c = min(8, alignment)

      A = TensorDescription(element_a, layout[0], alignment)
      B = TensorDescription(element_b, layout[1], alignment)
      C = TensorDescription(element_c, layout[2], alignment_c)

      swizzling_functor_ = swizzling_functor

      #
      # Conv2d Fprop
      #
      if ConvKind.Fprop in conv_kinds:

        # Analytic and Optimized Fprop (built with StrideSupport.Unity)
        for iterator_algorithm in iterator_algorithms:
          new_operations = [
            # Non-grouped kernel
            Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
              A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_),
          ]

          # Instance group conv kernel
          if tile.math_instruction.opcode_class == OpcodeClass.TensorOp and A.layout == LayoutType.TensorNHWC and \
            tile.minimum_compute_capability >= 80:
            # SingleGroup kernel
            new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
              A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_, group_mode=GroupMode.SingleGroup))

            # Analytic iterator supports MultipleGroup mode
            if iterator_algorithm == IteratorAlgorithm.Analytic:
              new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
                A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_, group_mode=GroupMode.MultipleGroup))

          for new_operation in new_operations:
            manifest.append(new_operation)
            operations.append(new_operation)

      #
      # Conv2d Dgrad
      #
      if ConvKind.Dgrad in conv_kinds:

        # Unity stride for Analytic and Optimized Dgrad
        for iterator_algorithm in iterator_algorithms:
          new_operation = Conv2dOperation(ConvKind.Dgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\
            A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_)

          manifest.append(new_operation)
          operations.append(new_operation)

        # Strided support for Analytic Dgrad
        # strided dgrad uses a special threadblock swizzle
        # note that SwizzlingFunctor.StridedDgradHorizontal might be
        # better for problem sizes with large activation channel count
        swizzling_functor_strided_dgrad_ = SwizzlingFunctor.StridedDgradIdentity1

        if IteratorAlgorithm.Analytic in iterator_algorithms:
          new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\
            A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_)

          manifest.append(new_operation)
          operations.append(new_operation)

        # Strided support for Optimized Dgrad
        if IteratorAlgorithm.Optimized in iterator_algorithms:
          new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Optimized, tile.minimum_compute_capability, tile,\
            A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_)

          manifest.append(new_operation)
          operations.append(new_operation)

      #
      # Conv2d Wgrad
      #
      if ConvKind.Wgrad in conv_kinds:

        # Strided support for Analytic and Optimized Wgrad
        for iterator_algorithm in iterator_algorithms:
          new_operation = Conv2dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\
            A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_)

          manifest.append(new_operation)
          operations.append(new_operation)

  return operations
iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# Convolution for 2D operations specialized for few channels +def CreateConv2dFewChannelsOperator(manifest, layout, tile_descriptions, data_type, channel_counts, \ + conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): + + element_a, element_b, element_c, element_epilogue = data_type + + # one exceptional case + + # iterator algorithm (analytic and optimized) + iterator_algorithms = [IteratorAlgorithm.FewChannels,] + + # by default, only generate the largest tile size, largest alignment, and optimized iterator + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + channel_counts = [channel_counts[0],] + + operations = [] + + for tile in tile_descriptions: + for channel_count in channel_counts: + + alignment_c = EpilogueAlignment(channel_count, tile) + + A = TensorDescription(element_a, layout[0], channel_count) + B = TensorDescription(element_b, layout[1], channel_count) + C = TensorDescription(element_c, layout[2], alignment_c) + + swizzling_functor_ = swizzling_functor + + # + # Conv2d Fprop + # + if ConvKind.Fprop in conv_kinds: + + # Strided support for Analytic and Optimized Fprop + for iterator_algorithm in iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# Convolution for 3D operations +def CreateConv3dOperator(manifest, layout, 
tile_descriptions, data_type, alignment, \ + conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], epilogue_functor = EpilogueFunctor.LinearCombination): + + element_a, element_b, element_c, element_epilogue = data_type + + # one exceptional case + alignment_c = min(8, alignment) + + # iterator algorithm (analytic and optimized) + iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized] + + # by default, only generate the largest tile size and optimized iterators + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + iterator_algorithms = [IteratorAlgorithm.Optimized] + + operations = [] + + # All tile sizes for Conv3dFprop and Conv3dWgrad + for tile in tile_descriptions: + A = TensorDescription(element_a, layout, alignment) + B = TensorDescription(element_b, layout, alignment) + C = TensorDescription(element_c, layout, alignment_c) + + # + # Conv3d Fprop + # + if ConvKind.Fprop in conv_kinds: + # Strided support for Analytic and Optimized Fprop + for iterator_algorithm in iterator_algorithms: + new_operation = Conv3dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided) + manifest.append(new_operation) + operations.append(new_operation) + # + # Conv3d Wgrad + # + if ConvKind.Wgrad in conv_kinds: + + # Strided support for Analytic and Optimized Wgrad + for iterator_algorithm in iterator_algorithms: + new_operation = Conv3dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor) + manifest.append(new_operation) + operations.append(new_operation) + + # All tile sizes for Conv3dDgrad + for tile in tile_descriptions: + + A = TensorDescription(element_a, layout, alignment) + B = TensorDescription(element_b, layout, alignment) + C = TensorDescription(element_c, layout, alignment_c) + + # + # Conv3d Dgrad + # + if ConvKind.Dgrad in 
conv_kinds: + # Unity stride for Optimized Dgrad + new_operation = Conv3dOperation(ConvKind.Dgrad, IteratorAlgorithm.Optimized, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor) + + manifest.append(new_operation) + operations.append(new_operation) + + # Strided support for Analytic Dgrad + # Conv3dDgrad has a naive strided support which does not cut down redundant MMAs + new_operation = Conv3dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# Convolution for Depthwise 2d conv +def CreateDepthwiseConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints, \ + conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): + + element_a, element_b, element_c, element_epilogue = data_type + + # iterator algorithm (FixedStrideDilation, Optimized) + iterator_algorithms = [IteratorAlgorithm.FixedStrideDilation, IteratorAlgorithm.Optimized] + + # by default, only generate the largest tile size, largest alignment, and optimized iterator + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + operations = [] + + for tile in tile_descriptions: + for alignment in alignment_constraints: + + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment) + B = TensorDescription(element_b, layout[1], alignment) + C = TensorDescription(element_c, layout[2], alignment_c) + + swizzling_functor_ = swizzling_functor + + if ConvKind.Fprop in conv_kinds: + + # Strided support for Optimized and FixedStridedDilation Depthwise Conv + for iterator_algorithm in iterator_algorithms: + 
stride_support = StrideSupport.Strided + if iterator_algorithm == IteratorAlgorithm.FixedStrideDilation: + if tile.stride == [-1, -1] or tile.dilation == [-1,-1]: + continue + stride_support = StrideSupport.Fixed + + if iterator_algorithm == IteratorAlgorithm.Optimized: + if tile.stride != [-1, -1] or tile.dilation != [-1,-1]: + continue + new_operation = Conv2dOperation(ConvKind.Fprop, + iterator_algorithm, + tile.minimum_compute_capability, + tile, + A, B, C, + element_epilogue, + stride_support, + epilogue_functor, + swizzling_functor_, + group_mode=GroupMode.Depthwise) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +class ConvOperation3x: + """All parameters of a CUTLASS 3 convolution operation. + + Unlike CUTLASS 2 convolutions, CUTLASS 3 convolutions do not + distinguish between 2-D and 3-D convolutions by kernel class name. + Instead, for CUTLASS 3 convolutions, the tensor layouts encode + whether the convolution is 2-D or 3-D. Thus, this class deduces + the OperationKind (either Conv2d or Conv3d) from the layouts, + rather than taking it as a constructor parameter. 
+ """ + def __init__(self, + conv_kind: ConvKind, + tile_description: TileDescription, + A: TensorDescription, + B: TensorDescription, + C: TensorDescription, + element_compute: Optional[DataType] = None, + D: Optional[TensorDescription] = None, + kernel_schedule: KernelScheduleType = KernelScheduleType.ScheduleAuto, + epilogue_schedule: EpilogueScheduleType = EpilogueScheduleType.ScheduleAuto, + tile_scheduler: TileSchedulerType = TileSchedulerType.Default, + log_indent_level: int = 1): + log_debug_line(f'ConvOperation3x::init: conv_kind: {conv_kind}', log_indent_level) + log_indent_level = log_indent_level + 1 + + self.conv_kind = conv_kind + self.tile_description = tile_description + self.A = A + self.B = B + self.C = C + self.element_compute = C.element if element_compute is None else element_compute + self.kernel_schedule = kernel_schedule + self.epilogue_schedule = epilogue_schedule + + self.arch = tile_description.minimum_compute_capability + self.tile_scheduler = tile_scheduler + if D == None: + self.D = C + else: + self.D = D + + self.is_3x = True + self.group_mode = GroupMode.NoneGroup # CUTLASS 3 convolutions currently aren't grouped + + operation_kind = None + for layout in (A.layout, B.layout, C.layout): + assert(isinstance(layout, LayoutType)) + new_operation_kind = convolution_tensor_layout_type_to_operation_kind(layout) + if operation_kind is None: + operation_kind = new_operation_kind + else: # CUTLASS 3 convolutions don't permit mixing 2-D and 3-D layouts. 
+ assert(operation_kind == new_operation_kind) + assert(operation_kind is not None) + self.operation_kind = operation_kind + + def __str__(self): + return f"ConvOperation3x: operation_kind={self.operation_kind}, conv_kind={self.conv_kind}, tile_description={self.tile_description}" + + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32 + ] + return self.tile_description.math_instruction.math_operation in complex_operators + + def is_mixed_input(self): + return self.A.element != self.B.element + + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + if self.is_complex(): + return get_complex_from_real(accum) + return accum + + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. 
''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and', + } + + tensor_ops = [ + OpcodeClass.TensorOp, + OpcodeClass.WmmaTensorOp, + OpcodeClass.SparseTensorOp, + OpcodeClass.BlockScaledTensorOp, + ] + + is_tensor_op = self.tile_description.math_instruction.opcode_class in tensor_ops + + if is_tensor_op: + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + return "%s%s%s" % (math_op_string, intermediate_type, ConvKindNames[self.conv_kind]) + + def extended_name(self): + '''Generates a string representing the MMA atom. 
Assumes accumulator type is C type.''' + extended_name = "{core_name}_{element_a}{layout_a}_{element_b}{layout_b}_{element_acc}_{element_c}_{element_d}{layout_c}".format( + element_a = DataTypeNames[self.A.element], + layout_a = ShortLayoutTypeNames[self.A.layout], + element_b = DataTypeNames[self.B.element], + layout_b = ShortLayoutTypeNames[self.B.layout], + element_acc = DataTypeNames[self.accumulator_type()], + element_c = DataTypeNames[self.C.element], + layout_c = ShortLayoutTypeNames[self.C.layout], + element_d = DataTypeNames[self.D.element], + core_name = self.core_name()) + + return extended_name + + # Generates a short string representing underlying kernel schedule type + def kernel_schedule_name(self): + return KernelScheduleSuffixes[self.kernel_schedule] + + # Generates a short string representing underlying epilogue schedule type + def epilogue_schedule_name(self): + return EpilogueScheduleSuffixes[self.epilogue_schedule] + + # Generate a short string representing the operation class + def opcode_class_name(self): + return OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + # Generates the full kernel function name + def configuration_name(self): + ''' The full function name indicates architecture, extended name, tile size, and layout. 
''' + kernel_name_template = "cutlass3x_sm{ar}_{op}_{ex}{ct}{cs}_{l}_align{al}{t}{k}{e}" + return kernel_name_template.format( + ar = self.arch, + op = self.opcode_class_name(), + ex = self.extended_name(), + ct = '_' + 'x'.join([str(i) for i in self.tile_description.tile_shape]) if self.tile_description.tile_shape[0] > 0 else "", + cs = '_' + 'x'.join([str(i) for i in self.tile_description.cluster_shape]), + l = self.tile_description.stages, + al = str(max(self.A.alignment, self.B.alignment)), + t = TileSchedulerSuffixes[self.tile_scheduler], + k = self.kernel_schedule_name(), + e = self.epilogue_schedule_name()) + + def procedural_name(self): + return self.configuration_name() + +def convolution_tensor_layout_type_to_operation_kind(layout: LayoutType) -> OperationKind: + if layout == LayoutType.TensorNHWC or layout == LayoutType.TensorKCSR: + return OperationKind.Conv2d + elif layout == LayoutType.TensorNDHWC or layout == LayoutType.TensorKCSRT: + return OperationKind.Conv3d + else: + raise RuntimeError(f'LayoutType {layout} does not have a corresponding OperationKind') + +def CreateConvOperator3x(manifest: Manifest, + dims_and_alignments: Sequence[Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]], + tile_descriptions: Sequence[Sequence[TileDescription]], + data_types, + schedule_pairs: Sequence[Tuple[KernelScheduleType, KernelScheduleType]] = \ + [(KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto)], + complex_transforms: Optional[Sequence[ComplexTransform]] = None, + tile_schedulers: Sequence[TileSchedulerType] = [TileSchedulerType.Default], + conv_kind: ConvKind = ConvKind.Fprop, + log_indent_level: int = 1): + """ + Create zero or more CUTLASS 3 two-dimensional convolution operators. + + Create a CUTLASS 3 two-dimensional convolution operator + for all feasible combinations of the input parameters. + Add the operators to the manifest. + + dims_and_alignments: 3-level list. Each outer list term is a list [A, B, C]. 
+ Each inner list (A, B, or C) has the form [num_spatial_dimensions, alignment]. + Both are integers; the first is the number of spatial dimensions + (currently, only 2 or 3 are supported), and the second is the byte alignment. + We deduce the operation_kind (either OperationKind.Conv2d or OperationKind.Conv3d) + from num_spatial_dimensions. + + This function doesn't take layouts, unlike the GEMM functions. + CUTLASS 3 convolutions currently support three input layouts: + + * TensorNWC for 1-D convolutions, + * TensorNHWC for 2-D convolutions, and + * TensorNDHWC for 3-D convolutions. + + Output (C and D) layouts are the same as input layouts, + except for Wgrad convolutions, where the layouts are + + * TensorKCS for 1-D convolutions, + * TensorKCSR for 2-D convolutions, and + * TensorKCSRT for 3-D convolutions. + + The output layouts are completely constrained by the input layouts + and the convolution kind. + + tile_descriptions: 2-level list. + Outer level has one list per math instruction. + Inner level has one TileDescription for each cluster shape. + + data_types: Either a single data_type dictionary, or a list of them. + Keys: 'a_type', 'b_type', 'c_type', 'd_type', 'acc_type', 'epi_type' + + complex_transforms: Optional list of pairs. + First element of each pair is the complex transform for A, and + second element of each pair is the complex transform for B. + + schedule_pairs: [(kernel_schedule, epilogue_schedule), ...] + + conv_kind: Convolution kind (Fprop, Dgrad, or Wgrad). 
+ """ + log_debug_line('CreateConvOperator3x', log_indent_level) + log_indent_level = log_indent_level + 1 + log_debug_line(f'conv_kind: {conv_kind}', log_indent_level) + + for triple in dims_and_alignments: + assert(isinstance(triple, tuple) or isinstance(triple, list)) + assert(len(triple) == 3) + + spatial_dimensionality = None # to be determined by loop below + + for entry in triple: # [A, B, C] + assert(len(entry) == 2) + [dim, alignment] = entry + assert(type(dim) is int) + assert(dim == 2 or dim == 3) + assert(type(alignment) is int) + assert(alignment > 0) + if spatial_dimensionality is None: + spatial_dimensionality = dim + else: + # A, B, and C need to have the same spatial dimensionality + assert(spatial_dimensionality == dim) + + def input_and_output_layouts(spatial_dim: int, kind: ConvKind) -> Tuple[LayoutType, LayoutType]: + if spatial_dim == 1: + input_layout = LayoutType.TensorNWC + if kind == ConvKind.Wgrad: + output_layout = LayoutType.TensorKCS + else: + output_layout = input_layout + elif spatial_dim == 2: + input_layout = LayoutType.TensorNHWC + if kind == ConvKind.Wgrad: + output_layout = LayoutType.TensorKCSR + else: + output_layout = input_layout + elif spatial_dim == 3: + input_layout = LayoutType.TensorNDHWC + if kind == ConvKind.Wgrad: + output_layout = LayoutType.TensorKCSRT + else: + output_layout = input_layout + else: + assert(False) + return (input_layout, output_layout) + + def dims_to_layouts(A_B_C: Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]) -> \ + Tuple[Tuple[LayoutType, int], Tuple[LayoutType, int], Tuple[LayoutType, int]]: + [A, B, C] = A_B_C + [spatial_dim, alignment] = A + [input_layout, output_layout] = input_and_output_layouts(spatial_dim, conv_kind) + return ((input_layout, A[1]), + (input_layout, B[1]), + (output_layout, C[1])) + + # layouts: list of triples (A, B, C). + # Each of A, B, and C has the form [layout, alignment]. 
+ layouts = [dims_to_layouts(A_B_C) for A_B_C in dims_and_alignments] + + if type(data_types) is dict: + data_types = [data_types] + + for s in schedule_pairs: + assert(len(s) == 2) + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none)] + + # product produces a one-pass generator, so the loop must call it anew each time. + def make_combinations(): + return product( + layouts, + tile_descriptions, + data_types, + complex_transforms, + schedule_pairs, + tile_schedulers + ) + + operations = [] + for layout_triple, tile_description, data_type, complex_transform_pair, schedule_pair, tile_scheduler in make_combinations(): + A_layout, A_alignment = layout_triple[0] + A_xform = complex_transform_pair[0] + B_layout, B_alignment = layout_triple[1] + B_xform = complex_transform_pair[1] + C_layout, C_alignment = layout_triple[2] + D_layout = C_layout + D_alignment = C_alignment + + A = TensorDescription(data_type["a_type"], A_layout, A_alignment, A_xform) + B = TensorDescription(data_type["b_type"], B_layout, B_alignment, B_xform) + C = TensorDescription(data_type["c_type"], C_layout, C_alignment) + D = TensorDescription(data_type["d_type"], D_layout, D_alignment) + element_compute = data_type.get("epi_type", data_type["acc_type"]) + kernel_schedule, epilogue_schedule = schedule_pair + + operation = ConvOperation3x(conv_kind=conv_kind, + tile_description=tile_description, + A=A, + B=B, + C=C, + element_compute=element_compute, + D=D, + kernel_schedule=kernel_schedule, + epilogue_schedule=epilogue_schedule, + tile_scheduler=tile_scheduler, + log_indent_level=log_indent_level) + log_debug_line(f'Created ConvOperation3x: {str(operation)}', log_indent_level) + manifest.append(operation) + operations.append(operation) + + return operations + +################################################################################################### 
+################################################################################################### + +# +def GenerateSM50_Simt(manifest, cuda_version): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + MathInstruction( \ + [1, 1, 1], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 50 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + if math_inst.element_a == DataType.f32: + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +# +def GenerateSM50_Simt_complex(manifest, cuda_version): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, 
LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add_complex), + ] + + min_cc = 50 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, + DataType.cf32, + DataType.cf32, + DataType.cf32, + ] + + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +# +def GenerateSM50(manifest, cuda_version): + GenerateSM50_Simt(manifest, cuda_version) + GenerateSM50_Simt_complex(manifest, cuda_version) + +################################################################################################### +################################################################################################### + +# +def GenerateSM60_Simt(manifest, cuda_version): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, 
LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 60 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) +# +def GenerateSM60_Simt_DepthwiseConv2d(manifest, cuda_version): + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 60 + max_cc = 1024 + + alignment_constraints = [8,] + + filter_3x3 = [3, 3] + filter_5x5 = [5, 5] + + # [stride_h, stride_w] + # [-1, -1] means all stride size. + strides = [[-1,-1], [1, 1], [2, 2]] + # [dilation_h, dilation_w] + # [-1, -1] means all dilation size. 
+ dilations = [[-1,-1], [1, 1], [2, 2]] + + #groups per thread block + g16 = 16 + g32 = 32 + g64 = 64 + + #output shape per thread block + npq_1x4x4 = [1, 4, 4] + npq_1x8x8 = [1, 8, 8] + npq_1x10x10 = [1, 10, 10] + + tile_descriptions = [] + for math_inst in math_instructions: + for stride, dilation in product(strides, dilations): + tile_descriptions.extend([ + # filter3x3 ThreadBlock_output, filter, stage, warp + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g32], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g64], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g16], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + + Direct2dConvFixedStrideDilationTileDescription(npq_1x10x10+[g64], filter_3x3, 2, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g32], filter_3x3, 4, stride, dilation, [4, 1, 1], math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g64], filter_3x3, 4, stride, dilation,[4, 1, 1], math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g16], filter_3x3, 4, stride, dilation, [4, 1, 1], math_inst, min_cc, max_cc), + + # filter5x5 ThreadBlock_output, filter, stage, warp + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g32], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g64], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g16], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + + Direct2dConvFixedStrideDilationTileDescription(npq_1x10x10+[g64], filter_5x5, 2, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + + 
Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g32], filter_5x5, 4, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g64], filter_5x5, 4, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g16], filter_5x5, 4, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc) + ]) + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateDepthwiseConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +# +def GenerateSM60(manifest, cuda_version): + GenerateSM60_Simt(manifest, cuda_version) + GenerateSM60_Simt_DepthwiseConv2d(manifest, cuda_version) + +################################################################################################### +################################################################################################### + +# +def GenerateSM61_Simt(manifest, cuda_version): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 4], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 61 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + 
TileDescription([ 64, 64, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) +# + +# +def GenerateSM61(manifest, cuda_version): + GenerateSM61_Simt(manifest, cuda_version) + +################################################################################################### +################################################################################################### + +# +def GenerateSM70_TensorOp_884(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 1): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 70 + max_cc = 75 + + alignment_constraints = [8, 4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, 
max_cc), + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) + +# +def GenerateSM70_PlanarComplexTensorOp_884(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 1): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 70 + max_cc = 75 + + alignment_constraints = [8, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, complex_transforms) + + +# +def GenerateSM70_WmmaTensorOp_161616(manifest, cuda_version): + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 16, 16], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.WmmaTensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 16, 16], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.WmmaTensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 70 + max_cc = 1024 + + alignment_constraints = [8,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + +# +################################################################################################## +# + +def GenerateSM70(manifest, cuda_version): + GenerateSM70_TensorOp_884(manifest, cuda_version) + GenerateSM70_PlanarComplexTensorOp_884(manifest, cuda_version) + + # To limit build size, WMMA GEMMs are disabled for now. + # + #GenerateSM70_WmmaTensorOp_161616(manifest, cuda_version) + +################################################################################################### +################################################################################################### + +# +def GenerateSM75_TensorOp_1688_FewChannels(manifest, cuda_version, math_inst): + + min_cc = 75 + max_cc = 1024 + + tile_descriptions = [ + TileDescription([128, 64, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 2, [2, 2, 2], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + + CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, data_type, [4, 8]) + CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, data_type, [1, 
2, 4]) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, [4, 8]) + CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, [1, 2, 4]) + +# +def GenerateSM75_TensorOp_1688(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [8, 4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], 
math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) + + # Separate generator for 'few channels' specializations + GenerateSM75_TensorOp_1688_FewChannels(manifest, cuda_version, math_inst) + +# + +# +def GenerateSM75_PlanarComplexTensorOp_1688(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f32, \ + 
OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [8, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([ 64, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, complex_transforms) + +# +def GenerateSM75_TensorOp_8816_TN(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 16], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [8, 8, 16], \ + DataType.u8, DataType.u8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 75 + max_cc = 90 + + alignment_constraints = [16,] + alignment_constraints_small_channels = [16, 8, 4] + + for math_inst in math_instructions: + 
tile_descriptions = [ + TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 32, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc), + + TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + DataType.s32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) + + conv_layout = (LayoutType.TensorNHWC, 
LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + else: + op.C.alignment = 8 + +# + +# +def GenerateSM75_TensorOp_8816_Interleaved(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajorInterleaved32, LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 16], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [8, 8, 16], \ + DataType.u8, DataType.u8, DataType.s32, \ + OpcodeClass.TensorOp, \ + 
MathOperation.multiply_add_saturate), + ] + + min_cc = 75 + max_cc = 90 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + op.C.alignment = 8 +# + +# +def GenerateSM75_TensorOp_8832_TN(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 32], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [8, 8, 32], \ + DataType.u4, DataType.u4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 75 + max_cc = 89 + + alignment_constraints = 
[32,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + DataType.s32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + elif op.tile_description.threadblock_shape[1] == 64: + op.C.alignment = 8 + else: + op.C.alignment = 8 + +# + +# +def GenerateSM75_TensorOp_8832_Interleaved(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajorInterleaved64, LayoutType.RowMajorInterleaved64, LayoutType.ColumnMajorInterleaved64), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 32], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [8, 8, 32], \ + DataType.u4, DataType.u4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 75 + max_cc = 89 + + alignment_constraints = [32,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + # Avoid emitting two kernels if 
the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + op.C.alignment = 16 +# + +# +def GenerateSM75_TensorOp_88128(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 128], \ + DataType.b1, DataType.b1, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.xor_popc), + ] + + min_cc = 75 + max_cc = { + MathOperation.xor_popc: 89, + MathOperation.and_popc: 90 + } + + alignment_constraints = [128,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 512], 2, [4, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 256, 512], 2, [2, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 128, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 256, 512], 2, [1, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([256, 64, 512], 2, [4, 1, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 128, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 64, 512], 2, [2, 2, 1], 
math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 64, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + ] + + data_type = [DataType.b1, DataType.b1, DataType.s32, DataType.s32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + +# + +# +def GenerateSM75_WmmaTensorOp_161616(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 16, 16], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.WmmaTensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + DataType.f32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) +# + +# +def GenerateSM75_Simt_complex(manifest, cuda_version): + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add_complex), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc) + ] + data_type = [ + DataType.cf32, + DataType.cf32, + DataType.cf32, + DataType.cf32 + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +def GenerateSM75(manifest, cuda_version): + GenerateSM75_TensorOp_1688(manifest, cuda_version) + GenerateSM75_PlanarComplexTensorOp_1688(manifest, cuda_version) + GenerateSM75_TensorOp_8816_TN(manifest, cuda_version) + GenerateSM75_TensorOp_8816_Interleaved(manifest, cuda_version) + GenerateSM75_TensorOp_8832_TN(manifest, cuda_version) + GenerateSM75_TensorOp_8832_Interleaved(manifest, cuda_version) + GenerateSM75_TensorOp_88128(manifest, cuda_version) + #GenerateSM75_WmmaTensorOp_161616(manifest, cuda_version) + GenerateSM75_Simt_complex(manifest, cuda_version) + + +################################################################################################### 
+################################################################################################### + +# +def GenerateSM80_TensorOp_16816(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 16], \ + DataType.bf16, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [8, 4, 2] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 
64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + CreateGemmGroupedOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) + CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, data_type, [4, 8]) + CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type, 8) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) + CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, [4, 8]) + CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type_mixed, 8) +# + +# +def GenerateSM80_SparseTensorOp_16832(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 1): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 32], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 32], \ + DataType.bf16, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [8] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 3, [4, 1, 1], 
math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + +# + +# +def GenerateSM80_PlanarComplexTensorOp_16816(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + 
math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 16], \ + DataType.bf16, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [8, ] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([ 64, 128, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, complex_transforms) + +# +def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + # Upcast on Operand A + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.s8, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.u8, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.s8, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.u8, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.s8, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.u8, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + ] + + min_cc = 80 + max_cc = 1024 + + # For mixed-input alignment constraints are a list of lists, where the + # inner list contains the alignment constraints for operands/matrices + # [[alignA, alignB, alignC],..] 
+ alignment_constraints = [[16, 8, 8],] + + for math_inst in math_instructions: + tile_descriptions = [ + # 128x128 + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x64 + TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x32 + TileDescription([128, 32, 64], 9, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x16 + TileDescription([128, 16, 64], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 64], 3, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_b != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_b, + math_inst.element_accumulator, + ] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + for op in operations: + if (DataTypeSize[op.C.element] == 16) and \ + (op.tile_description.threadblock_shape[1] <= 32): + op.C.alignment = 4 + +# +def GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.s8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.u8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.bf16, DataType.s8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.bf16, DataType.u8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.s8, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.u8, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + ] + + min_cc = 80 + max_cc = 1024 + + # For mixed-input alignment constraints are a list of lists, where the + # inner list contains the alignment constraints for operands/matrices + # [[alignA, alignB, alignC],..] 
+ alignment_constraints = [[8, 16, 8],] + + for math_inst in math_instructions: + tile_descriptions = [ + # 128x128 + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x64 + TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x32 + TileDescription([128, 32, 64], 9, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 32], 9, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x16 + TileDescription([128, 16, 64], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 64], 3, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 32], 9, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 32], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 32], 3, [2, 1, 1], math_inst, min_cc, max_cc), + # 256x16 + TileDescription([256, 16, 32], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 16, 32], 3, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + for op in operations: + if op.tile_description.threadblock_shape[1] <= 32: + op.C.alignment = 4 + +# +def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [16, 8, 32], \ + DataType.u8, DataType.u8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 80 + max_cc = 1024 + smem_usage = 164 + + alignment_constraints = [16,] + alignment_constraints_small_channels = [16, 8, 4] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, 
max_cc), + TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32] + data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], 
EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + if op.tile_description.threadblock_shape[0] == 32: + op.C.alignment = 8 + else: + op.C.alignment = 16 + else: + op.C.alignment = 8 + +# + +def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + # Upcast on Operand A + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.s4, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + ] + + min_cc = 80 + max_cc = 1024 + + # For mixed-input alignment constraints are a list of lists, where the + # inner list contains the alignment constraints for operands/matrices + # [[alignA, alignB, alignC],..] 
+ alignment_constraints = [[32, 16, 4],] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
S8 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + alignment_constraints = [[32, 16, 16],] + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_b, + DataType.f32 + ] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp, SwizzlingFunctor.Identity8) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + if op.tile_description.threadblock_shape[0] == 32: + op.C.alignment = 8 + else: + op.C.alignment = 16 + else: + op.C.alignment = 8 +# + +# +def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + # Upcast on Operand B + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.s8, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + ] + + min_cc = 80 + max_cc = 1024 + + # For mixed-input alignment constraints are a list of lists, where the + # inner list contains the alignment constraints for operands/matrices + # [[alignA, alignB, alignC],..] 
+ alignment_constraints = [[16, 32, 4],] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
S8 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + alignment_constraints = [[16, 32, 16],] + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp, SwizzlingFunctor.Identity8) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + if op.tile_description.threadblock_shape[0] == 32: + op.C.alignment = 8 + else: + op.C.alignment = 16 + else: + op.C.alignment = 8 +# + +# +def GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 1): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 64], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [16,] + + tile_descriptions = [ + TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 
256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.s8, DataType.s8, DataType.s32, DataType.s32] + data_type_mixed = [DataType.s8, DataType.s8, DataType.s8, DataType.f32] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) + + operations = [] + + operations += CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + else: + op.C.alignment = 8 +# + +# +def GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajorInterleaved32, LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [16, 8, 32], \ + DataType.u8, DataType.u8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, 
max_cc), + ] + + data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] + + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + op.C.alignment = 8 +# + +# +def GenerateSM80_TensorOp_16864_TN(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 64], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [16, 8, 64], \ + DataType.u4, DataType.u4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 80 + max_cc = 1024 + alignment_constraints = [32,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc), + 
TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 256], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 256], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32] + data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + elif op.tile_description.threadblock_shape[1] == 64: + op.C.alignment = 8 + else: + op.C.alignment = 8 +# + +# +def GenerateSM80_SparseTensorOp_168128_TN(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 1): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, 
LayoutType.RowMajor), + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 128], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate) + + min_cc = 80 + max_cc = 1024 + alignment_constraints = [32,] + + tile_descriptions = [ + TileDescription([ 64, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 256], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 256], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 512], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 512], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 512], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 512], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.s4, DataType.s4, DataType.s32, DataType.s32] + data_type_mixed = [DataType.s4, DataType.s4, DataType.s4, DataType.f32] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) + + operations = [] + + operations += CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] > 128: + op.C.alignment = 16 + else: + op.C.alignment = 8 +# + +# +def GenerateSM80_TensorOp_16864_Interleaved(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + 
(LayoutType.ColumnMajorInterleaved64, LayoutType.RowMajorInterleaved64, LayoutType.ColumnMajorInterleaved64), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 64], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [16, 8, 64], \ + DataType.u4, DataType.u4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 80 + max_cc = 1024 + alignment_constraints = [32,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + op.C.alignment = 16 +# + +# +def GenerateSM80_TensorOp_168256(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 256], \ + DataType.b1, DataType.b1, DataType.s32, \ + 
OpcodeClass.TensorOp, \ + MathOperation.xor_popc), + MathInstruction( \ + [16, 8, 256], \ + DataType.b1, DataType.b1, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.and_popc), + ] + + min_cc = 80 + max_cc = { + MathOperation.xor_popc: 89, + MathOperation.and_popc: 90 + } + + alignment_constraints = [128,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 512], 3, [4, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 256, 512], 3, [2, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([256, 64, 512], 4, [4, 1, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 256, 512], 4, [1, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 128, 512], 5, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 64, 512], 6, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 128, 512], 6, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 64, 512], 10, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([256, 128, 1024], 3, [4, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 256, 1024], 3, [2, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([256, 64, 1024], 4, [4, 1, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 256, 1024], 4, [1, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 128, 1024], 4, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 64, 1024], 3, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 128, 1024], 3, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 64, 1024], 5, [2, 2, 1], math_inst, min_cc, 
max_cc[math_inst.math_operation]), + ] + + data_type = [DataType.b1, DataType.b1, DataType.s32, DataType.s32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + +# + +# +def GenerateSM80_TensorOp_1688(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], 
math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_f16), + MathInstruction( \ + [16, 8, 8], \ + DataType.bf16, 
DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_bf16), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, 
tile_descriptions, data_type, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + 
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +def GenerateSM80_TensorOp_1688_fast_fp32_math_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_fast_f32) + + min_cc = 80 + max_cc = 1024 + + tile_descriptions = [ + TileDescription([128, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, DataType.cf32, DataType.cf32, DataType.cf32 + ] + + alignment_constraints = [1,] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + + +# +def GenerateSM80_SparseTensorOp_16816_fast_math(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 1): + return + + layouts = [ + (LayoutType.ColumnMajor, 
LayoutType.ColumnMajor, LayoutType.RowMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [4] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_1688_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, 
LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 80 + max_cc = 1024 + + tile_descriptions = [ + TileDescription([128, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, DataType.cf32, DataType.cf32, DataType.cf32 + ] + + alignment_constraints = [1,] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM80_TensorOp_1688_rank_k(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, 
\ + MathOperation.multiply_add_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1, 2, 4] # Alignment only applies to A in SYRK + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + #TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + #TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32] + + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM80_TensorOp_1688_rank_k_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_instructions = [ + 
MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex), + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 64, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 16], 4, [2, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, DataType.cf32, DataType.cf32 + ] + + alignment_constraints = [1,] + + # SYRK + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HERK + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM80_TensorOp_1688_trmm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 
+ + alignment_constraints = [1, 2, 4] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + #TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_1688_trmm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, 
DiagType.Unit, + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex), + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 64, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, DataType.cf32, DataType.cf32, DataType.cf32 + ] + + alignment_constraints = [1,] + + complex_transforms = [ + ComplexTransform.none, ComplexTransform.conj, + ] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM80_TensorOp_1688_symm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + # A and B have same layouts + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [ + 1, 2, 4 + ] + + for math_inst in math_instructions: + tile_descriptions = [ + 
TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + #TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + #TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM80_TensorOp_1688_symm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex), + MathInstruction( \ + [16, 8, 8], \ + 
DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 64, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 16], 4, [2, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, DataType.cf32, DataType.cf32, DataType.cf32 + ] + + alignment_constraints = [1,] + + # SYMM + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HEMM + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM80_TensorOp_884(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 16], 3, [4, 1, 1], math_inst, 
min_cc, max_cc), + TileDescription([32, 256, 16], 3, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_884_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8 ], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8 ], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8 ], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 8 ], 4, [2, 2, 1], math_inst, 
min_cc, max_cc), + TileDescription([16, 32, 8 ], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 8 ], 4, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 3, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + +# +def GenerateSM80_TensorOp_884_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + 
TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM80_TensorOp_884_rank_k(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 
1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64] + + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM80_TensorOp_884_rank_k_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64] + + # SYRK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HERK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) + +# + +# +def GenerateSM80_TensorOp_884_rank_k_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), 
+ (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ComplexTransform.none,] + + # SYRK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HERK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM80_TensorOp_884_trmm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, 
[4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_884_trmm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + ComplexTransform.none, ComplexTransform.conj, + ] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + + +# +def GenerateSM80_TensorOp_884_trmm_complex_gaussian(manifest, cuda_version): + + if not 
CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + ComplexTransform.none, ComplexTransform.conj, + ] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM80_TensorOp_884_symm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + 
TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM80_TensorOp_884_symm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + # SYMM computation + CreateSymmOperator(manifest, layouts, side_modes, 
fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HEMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM80_TensorOp_884_symm_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ComplexTransform.none,] + + # SYMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HEMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +################################################################################################### + +# +def GenerateSM80_Simt_f32(manifest, cuda_version): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, 
LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 8], 5, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 8], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 8], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 5, [1, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + + +# +def GenerateSM80_Simt_f64(manifest, cuda_version): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, 
LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 5, [1, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) +# + + +################################################################################################## +# +def GenerateSM80_Simt_complex(manifest, cuda_version): + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add_complex), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + data_type = [ + DataType.cf32, + DataType.cf32, + DataType.cf32, + DataType.cf32 + ] + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + complex_transforms = [ + 
(ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + for math_inst in math_instructions: + + tile_descriptions = [ + TileDescription([128, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints, complex_transforms) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +################################################################################################### + +# +def GenerateSM80(manifest, cuda_version): + GenerateSM80_TensorOp_16816(manifest, cuda_version) + GenerateSM80_SparseTensorOp_16832(manifest, cuda_version) + GenerateSM80_PlanarComplexTensorOp_16816(manifest, cuda_version) + GenerateSM80_TensorOp_1688(manifest, cuda_version) + GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version) + GenerateSM80_SparseTensorOp_16816_fast_math(manifest, cuda_version) + GenerateSM80_TensorOp_1688_complex(manifest, cuda_version) + # 3xTF32 + GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version) + GenerateSM80_TensorOp_1688_fast_fp32_math_complex(manifest, cuda_version) + GenerateSM80_TensorOp_1688_rank_k(manifest, cuda_version) + 
GenerateSM80_TensorOp_1688_rank_k_complex(manifest, cuda_version) + GenerateSM80_TensorOp_1688_trmm(manifest, cuda_version) + GenerateSM80_TensorOp_1688_trmm_complex(manifest, cuda_version) + GenerateSM80_TensorOp_1688_symm(manifest, cuda_version) + GenerateSM80_TensorOp_1688_symm_complex(manifest, cuda_version) + GenerateSM80_TensorOp_884(manifest, cuda_version) + GenerateSM80_TensorOp_884_complex(manifest, cuda_version) + GenerateSM80_TensorOp_884_complex_gaussian(manifest, cuda_version) + GenerateSM80_TensorOp_884_rank_k(manifest, cuda_version) + GenerateSM80_TensorOp_884_rank_k_complex(manifest, cuda_version) + GenerateSM80_TensorOp_884_rank_k_complex_gaussian(manifest, cuda_version) + GenerateSM80_TensorOp_884_trmm(manifest, cuda_version) + GenerateSM80_TensorOp_884_trmm_complex(manifest, cuda_version) + GenerateSM80_TensorOp_884_trmm_complex_gaussian(manifest, cuda_version) + GenerateSM80_TensorOp_884_symm(manifest, cuda_version) + GenerateSM80_TensorOp_884_symm_complex(manifest, cuda_version) + GenerateSM80_TensorOp_884_symm_complex_gaussian(manifest, cuda_version) + GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version) + GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version) + GenerateSM80_TensorOp_16832_TN(manifest, cuda_version) + GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version) + GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version) + GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version) + GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version) + GenerateSM80_TensorOp_16864_TN(manifest, cuda_version) + GenerateSM80_SparseTensorOp_168128_TN(manifest, cuda_version) + GenerateSM80_TensorOp_16864_Interleaved(manifest, cuda_version) + GenerateSM80_TensorOp_168256(manifest, cuda_version) + GenerateSM80_Simt_f32(manifest, cuda_version) + GenerateSM80_Simt_f64(manifest, cuda_version) + GenerateSM80_Simt_complex(manifest, cuda_version) + 
+################################################################################################### + +def GenerateSM89_TensorOp_16832_fp8(manifest, element_acc): + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor) + ] + + math_instructions = [ + MathInstruction( + [16, 8, 32], + DataType.e4m3, DataType.e4m3, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 32], + DataType.e4m3, DataType.e5m2, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 32], + DataType.e5m2, DataType.e4m3, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 32], + DataType.e5m2, DataType.e5m2, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 32], + DataType.e4m3, DataType.e4m3, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + MathInstruction( + [16, 8, 32], + DataType.e4m3, DataType.e5m2, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + MathInstruction( + [16, 8, 32], + DataType.e5m2, DataType.e4m3, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + MathInstruction( + [16, 8, 32], + DataType.e5m2, DataType.e5m2, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + ] + + min_cc = 89 + max_cc = 100 + alignment_constraints = [16,] + alignment_constraints_small_channels = [16, 8, 4] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 64], 6, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 
1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 6, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, 
max_cc), + TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_types = [ + [ + math_inst.element_a, + math_inst.element_b, + DataType.f32, + math_inst.element_accumulator + ], + [ + math_inst.element_a, + math_inst.element_b, + DataType.bf16, + math_inst.element_accumulator + ], + ] + + operations = [] + for data_type in data_types: + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, + alignment_constraints, None, EpilogueFunctor.LinearCombination) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, + data_type, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + if op.tile_description.threadblock_shape[0] == 32: + op.C.alignment = 8 + else: + op.C.alignment = 16 + 
else: + op.C.alignment = 8 + +def GenerateSM89_TensorOp_16832_fp8_fp32acc(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 4): + return + + GenerateSM89_TensorOp_16832_fp8(manifest, DataType.f32) + +def GenerateSM89_TensorOp_16832_fp8_fp16acc(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + GenerateSM89_TensorOp_16832_fp8(manifest, DataType.f16) + +# +def GenerateSM89_SparseTensorOp_16864_fp8(manifest, cuda_version): + + if ( + not CudaToolkitVersionSatisfies(cuda_version, 12, 4) + ): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor) + ] + + math_instructions = [ + MathInstruction( + [16, 8, 64], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 64], + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 64], + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 64], + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 64], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + MathInstruction( + [16, 8, 64], + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + MathInstruction( + [16, 8, 64], + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + MathInstruction( + [16, 8, 64], + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + ] + + min_cc = 89 + max_cc = 89 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 64, 128], 3, [2, 2, 1], 
math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_types = [ + [ + math_inst.element_a, + math_inst.element_b, + DataType.f32, + math_inst.element_accumulator + ], + ] + + operations = [] + for data_type in data_types: + operations += CreateSparseGemmOperator(manifest, layouts, tile_descriptions, data_type, + alignment_constraints, None, EpilogueFunctor.LinearCombination) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + else: + op.C.alignment = 8 + +################################################################################################### + +# +def GenerateSM89(manifest, cuda_version): + GenerateSM89_TensorOp_16832_fp8_fp32acc(manifest, cuda_version) + GenerateSM89_TensorOp_16832_fp8_fp16acc(manifest, cuda_version) + GenerateSM89_SparseTensorOp_16864_fp8(manifest, cuda_version) + +################################################################################################### + + +try: + from .sm90_utils import ( + generate_fp16_bf16_math_instructions_sm90, + generate_tf32_math_instructions_sm90, + generate_int8_math_instructions_sm90, + generate_fp8_math_instructions_sm90, + 
generate_mixed_dtype_math_instructions_sm90, + make_sparse_math_instructions, + generate_tile_descriptions_sm90, + get_valid_schedules, + generate_data_types_from_math_instruction, + fix_alignments, + ) +except ImportError: + from sm90_utils import ( + generate_fp16_bf16_math_instructions_sm90, + generate_tf32_math_instructions_sm90, + generate_int8_math_instructions_sm90, + generate_fp8_math_instructions_sm90, + generate_mixed_dtype_math_instructions_sm90, + make_sparse_math_instructions, + generate_tile_descriptions_sm90, + get_valid_schedules, + generate_data_types_from_math_instruction, + fix_alignments, + ) + +def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=131, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments. 
+ layouts = [ + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = generate_fp16_bf16_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_type_w_source = generate_data_types_from_math_instruction(math_inst) + data_type_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) + data_types = [data_type_w_source, data_type_wo_source] + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_type_mixed_w_source = generate_data_types_from_math_instruction( + math_inst, + element_source=math_inst.element_a, + element_dest=math_inst.element_a + ) + data_type_mixed_wo_source = generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.void, + element_dest=math_inst.element_a + ) + data_types.append(data_type_mixed_w_source) + data_types.append(data_type_mixed_wo_source) + + for layout in layouts: + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + gemm_kind=gemm_kind, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules, gemm_kind=gemm_kind) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_16b_WGMMA_alignx_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=101, exhaustive_level=9992) + is_aligned = False + + # layouts for ABC and their alignments. 
+ layouts = [ + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = generate_fp16_bf16_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_type_w_source = generate_data_types_from_math_instruction(math_inst) + data_types = [data_type_w_source] + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_type_mixed_w_source = generate_data_types_from_math_instruction( + math_inst, + element_source=math_inst.element_a, + element_dest=math_inst.element_a + ) + data_types.append(data_type_mixed_w_source) + + for layout in layouts: + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + +def GenerateSM90_SparseTensorOp_16b_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=131, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments. 
+ layouts = [ + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = make_sparse_math_instructions(generate_fp16_bf16_math_instructions_sm90(instantiation_level)) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_type_w_source = generate_data_types_from_math_instruction(math_inst) + data_type_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) + data_types = [data_type_w_source, data_type_wo_source] + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_type_mixed_w_source = generate_data_types_from_math_instruction( + math_inst, + element_source=math_inst.element_a, + element_dest=math_inst.element_a + ) + data_type_mixed_wo_source = generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.void, + element_dest=math_inst.element_a + ) + data_types.append(data_type_mixed_w_source) + data_types.append(data_type_mixed_wo_source) + + for layout in layouts: + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=120, default_level=121, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], + ] + + math_instructions = generate_tf32_math_instructions_sm90(instantiation_level) + tile_descriptions = 
generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + + for layout in layouts: + data_type_tf32 = generate_data_types_from_math_instruction(math_inst) + data_type_tf32_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) + data_type_f32 = copy.deepcopy(data_type_tf32) + data_type_f32_wo_source = copy.deepcopy(data_type_tf32_wo_source) + data_type_f32["a_type"] = DataType.f32 + data_type_f32["b_type"] = DataType.f32 + data_type_f32["epi_type"] = DataType.f32 + data_type_f32_wo_source["a_type"] = DataType.f32 + data_type_f32_wo_source["b_type"] = DataType.f32 + data_type_f32_wo_source["epi_type"] = DataType.f32 + data_types = [data_type_tf32, data_type_f32, data_type_tf32_wo_source, data_type_f32_wo_source] + + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_tf32_WGMMA_alignx_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=101, exhaustive_level=9992) + is_aligned = False + + # layouts for ABC and their alignments. 
+ layouts = [ + [[LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 1], [LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 1], [LayoutType.ColumnMajor, 1], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 1], [LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = generate_tf32_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + + for layout in layouts: + # Inconsistency: TF32 does not stamp out void-C + data_type_tf32 = generate_data_types_from_math_instruction(math_inst) + data_type_f32 = copy.deepcopy(data_type_tf32) + data_type_f32["a_type"] = DataType.f32 + data_type_f32["b_type"] = DataType.f32 + data_type_f32["epi_type"] = DataType.f32 + for data_type in [data_type_tf32, data_type_f32]: + # Inconsistency: alignments aren't fixed in TF32 / alignx + # layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + 
CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_SparseTensorOp_tf32_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=120, default_level=121, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + ] + + math_instructions = make_sparse_math_instructions(generate_tf32_math_instructions_sm90(instantiation_level)) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + + for layout in layouts: + data_type_tf32 = generate_data_types_from_math_instruction(math_inst) + data_type_tf32_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) + data_type_f32 = copy.deepcopy(data_type_tf32) + data_type_f32_wo_source = copy.deepcopy(data_type_tf32_wo_source) + data_type_f32["a_type"] = DataType.f32 + data_type_f32["b_type"] = DataType.f32 + data_type_f32["epi_type"] = DataType.f32 + data_type_f32_wo_source["a_type"] = DataType.f32 + data_type_f32_wo_source["b_type"] = DataType.f32 + data_type_f32_wo_source["epi_type"] = DataType.f32 + data_types = [data_type_tf32, data_type_f32, data_type_tf32_wo_source, data_type_f32_wo_source] + + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + 
CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16]], + ] + + math_instructions = generate_int8_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_type_w_source = generate_data_types_from_math_instruction(math_inst) + data_type_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) + data_type_int8_output = generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.s8, + element_dest=math_inst.element_a, + element_epilogue=DataType.f32 + ) + data_types = [data_type_w_source, data_type_wo_source, data_type_int8_output] + + for layout in layouts: + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if 
len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_int8_WGMMA_alignx_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) + is_aligned = False + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = generate_int8_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_type_w_source = generate_data_types_from_math_instruction(math_inst) + data_type_int8_output = generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.s8, + element_dest=math_inst.element_a, + element_epilogue=DataType.f32 + ) + data_types = [data_type_w_source, data_type_int8_output] + + for layout in layouts: + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], 
data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_SparseTensorOp_int8_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16]], + ] + + math_instructions = make_sparse_math_instructions(generate_int8_math_instructions_sm90(instantiation_level)) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + # s8.u8 and u8.s8 wgmma variants require PTX 8.4 + if math_inst.element_a != math_inst.element_b and not CudaToolkitVersionSatisfies(cuda_version, 12, 4): + continue + data_type_w_source = generate_data_types_from_math_instruction(math_inst) + data_type_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) + data_type_int8_output = generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.s8, + element_dest=math_inst.element_a, + element_epilogue=DataType.f32 + ) + data_types = [data_type_w_source, data_type_wo_source, data_type_int8_output] + + for layout in layouts: + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert 
CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]], # TN Layout + ] + + math_instructions = generate_fp8_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_types = [] + fp8_types = [DataType.e4m3, DataType.e5m2] + valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] + valid_types_for_c = copy.deepcopy(valid_types_for_d) + valid_types_for_c.append(DataType.void) + for c_type, d_type in product(valid_types_for_c, valid_types_for_d): + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=c_type, + element_dest=d_type, + ) + ) + else: + for d_type in valid_types_for_d: + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.void, + element_dest=d_type, + ) + ) + + for layout in layouts: + for data_type in data_types: + # Inconsistency: alignments aren't fixed in FP8 + # layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + 
data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + gemm_kind=gemm_kind, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules, gemm_kind=gemm_kind) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + +def GenerateSM90_TensorOp_fp8_WGMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]], # TN Layout + ] + + math_instructions = generate_fp8_math_instructions_sm90(instantiation_level) + tile_descriptions_ = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + tile_descriptions = list() + + for desc in tile_descriptions_: + desc.explicit_vector_sizes = [1, desc.tile_shape[1], desc.tile_shape[2]] + tile_descriptions.append(copy.deepcopy(desc)) + desc.explicit_vector_sizes = [desc.tile_shape[0], desc.tile_shape[1], desc.tile_shape[2]] + tile_descriptions.append(copy.deepcopy(desc)) + desc.explicit_vector_sizes = [desc.tile_shape[0], desc.tile_shape[1], desc.tile_shape[2]] + tile_descriptions.append(copy.deepcopy(desc)) + desc.explicit_vector_sizes = [1, 1, desc.tile_shape[2]] + tile_descriptions.append(copy.deepcopy(desc)) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_types = [] + fp8_types = [DataType.e4m3, DataType.e5m2] + valid_types_for_d = [DataType.f32, 
DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] + valid_types_for_c = copy.deepcopy(valid_types_for_d) + valid_types_for_c.append(DataType.void) + for c_type, d_type in product(valid_types_for_c, valid_types_for_d): + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=c_type, + element_dest=d_type, + ) + ) + else: + for d_type in valid_types_for_d: + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.void, + element_dest=d_type, + ) + ) + + for layout in layouts: + for data_type in data_types: + # Inconsistency: alignments aren't fixed in FP8 + # layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + gemm_kind=gemm_kind, + enable_fp8_fast_acc=False, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules, gemm_kind=gemm_kind) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK], + gemm_kind=gemm_kind) + + + +def GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=0, default_level=101, exhaustive_level=9992) + is_aligned = False + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], # TN Layout + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], # TN Layout + ] + + math_instructions = 
generate_fp8_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_types = [generate_data_types_from_math_instruction(math_inst)] + fp8_types = [DataType.e4m3, DataType.e5m2] + valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] + valid_types_for_c = copy.deepcopy(valid_types_for_d) + valid_types_for_c.append(DataType.void) + for c_type, d_type in product(valid_types_for_c, valid_types_for_d): + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=c_type, + element_dest=d_type, + ) + ) + + for layout in layouts: + for data_type in data_types: + # Inconsistency: alignments aren't fixed in FP8 + # layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + +def GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 1): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9999) + is_aligned = True + + # layouts for ABC, their alignments will be fixed later based on the data type + layouts = [ + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], 
[LayoutType.ColumnMajor, 16]], + ] + + valid_types_for_a_b_acc = [ + (DataType.e4m3, DataType.f16, DataType.f32), + (DataType.e4m3, DataType.bf16, DataType.f32), + (DataType.e5m2, DataType.f16, DataType.f32), + (DataType.e5m2, DataType.bf16, DataType.f32), + (DataType.s8, DataType.f16, DataType.f32), + (DataType.s8, DataType.bf16, DataType.f32), + (DataType.u8, DataType.f16, DataType.f32), + (DataType.u8, DataType.bf16, DataType.f32), + (DataType.s4, DataType.f16, DataType.f32), + (DataType.s4, DataType.bf16, DataType.f32), + (DataType.s4, DataType.e4m3, DataType.f32), + (DataType.s4, DataType.e5m2, DataType.f32), + (DataType.u4, DataType.f16, DataType.f32), + (DataType.u4, DataType.bf16, DataType.f32), + (DataType.u2, DataType.f16, DataType.f32), + (DataType.u2, DataType.bf16, DataType.f32), + (DataType.s2, DataType.f16, DataType.f32), + (DataType.s2, DataType.bf16, DataType.f32), + ] + # Note: For sizeof(a_type) > sizeof(b_type), some generated kernels might crash due to a compiler bug. Disable it for now. + #swapped_valid_types_for_a_b_acc = [(b_type, a_type, acc_type) for a_type, b_type, acc_type in valid_types_for_a_b_acc] + #valid_types_for_a_b_acc = valid_types_for_a_b_acc + swapped_valid_types_for_a_b_acc + + math_instructions = generate_mixed_dtype_math_instructions_sm90(instantiation_level, valid_types_for_a_b_acc) + + valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] + valid_types_for_c = copy.deepcopy(valid_types_for_d) + + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_types = [] + + # Limit C/D types to avoid a giant number of instantiations. + # A typical use case for mixed dtype in DL is weight quantization (tensor A), + # therefore we can limit the output type to that of activation (tensor B). 
+ valid_types_for_c = [math_inst.element_b] + valid_types_for_d = [math_inst.element_b] + + for c_type, d_type in product(valid_types_for_c, valid_types_for_d): + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=c_type, + element_dest=d_type, + ) + ) + + for layout in layouts: + for data_type in data_types: + # Fix alignments, DataTypeSize are in the unit of bits + alignment_bits = 128 + layout[0][1] = alignment_bits // DataTypeSize[data_type['a_type']] + layout[1][1] = alignment_bits // DataTypeSize[data_type['b_type']] + layout[2][1] = alignment_bits // DataTypeSize[data_type['c_type']] + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_SparseTensorOp_fp8_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]], # TN Layout + ] + + math_instructions = make_sparse_math_instructions(generate_fp8_math_instructions_sm90(instantiation_level)) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction 
+ data_types = [] + fp8_types = [DataType.e4m3, DataType.e5m2] + valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] + valid_types_for_c = copy.deepcopy(valid_types_for_d) + valid_types_for_c.append(DataType.void) + for c_type, d_type in product(valid_types_for_c, valid_types_for_d): + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=c_type, + element_dest=d_type, + ) + ) + else: + for d_type in valid_types_for_d: + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.void, + element_dest=d_type, + ) + ) + + for layout in layouts: + for data_type in data_types: + # Inconsistency: alignments aren't fixed in FP8 + # layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_1684(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = MathInstruction( + [16, 8, 4], + DataType.f64, DataType.f64, DataType.f64, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + + 
min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 16], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([32, 256, 16], 3, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateGemmOperator(manifest, layouts, tile_descriptions, + data_type, alignment_constraints) + +# + +# +def GenerateSM90_TensorOp_1684_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8 ], 3, [4, 2, 1], math_inst, 
min_cc, max_cc), + TileDescription([64, 128, 8 ], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8 ], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 8 ], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 8 ], 4, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 3, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, 
LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], 
math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64] + + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM90_TensorOp_1684_rank_k_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64] + + # SYRK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HERK 
computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) + +# + +# +def GenerateSM90_TensorOp_1684_rank_k_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ComplexTransform.none,] + + # SYRK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HERK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM90_TensorOp_1684_trmm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + 
FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints) +# + +# +def GenerateSM90_TensorOp_1684_trmm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, 
DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + ComplexTransform.none, ComplexTransform.conj, + ] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + + +# +def GenerateSM90_TensorOp_1684_trmm_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + ComplexTransform.none, ComplexTransform.conj, + ] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM90_TensorOp_1684_symm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, 
\ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM90_TensorOp_1684_symm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + 
#TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + # SYMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HEMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ComplexTransform.none,] + + # SYMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HEMM computation + 
CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + + + +# Blackwell SM 100 generators + +try: + import cutlass_library.sm100_utils + from cutlass_library.sm100_utils import ( + generate_tf32_math_instructions_sm100, + generate_16b_math_instructions_sm100, + generate_f8f6f4_math_instructions_sm100, + generate_mxf8f6f4_math_instructions_sm100, + generate_mxf4nvf4_math_instructions_sm100, + generate_fp8_math_instructions_sm100, + generate_cluster_shapes_sm100, + get_pruning_level_from_global_level + ) +except ImportError: + import sm100_utils + from sm100_utils import ( + generate_tf32_math_instructions_sm100, + generate_16b_math_instructions_sm100, + generate_f8f6f4_math_instructions_sm100, + generate_mxf8f6f4_math_instructions_sm100, + generate_mxf4nvf4_math_instructions_sm100, + generate_fp8_math_instructions_sm100, + generate_cluster_shapes_sm100, + get_pruning_level_from_global_level + ) + +################################################################################################### + +def get_tma_alignment_elt(data_type : DataType, is_f8f6f4 : bool = True ) -> int: + if DataTypeSize[data_type] < 8 and is_f8f6f4: + return int(128) + return int(16 * 8 / DataTypeSize[data_type]) + +sm100_cluster_shape_1sm = [ + [4,4,1] + , DynamicClusterShape +] + +sm100_cluster_shape_2sm = [ + # cluster_m % 2 == 0 for 2sm + [4,4,1] + , DynamicClusterShape +] + +def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=490, default_level=490, exhaustive_level=9999) + + # layouts for ABC and their alignments. 
+ layouts = [ + [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4]], + ] + + data_types = [ + { + "a_type" : DataType.f32, + "b_type" : DataType.f32, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + { + "a_type" : DataType.f32, + "b_type" : DataType.f32, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + math_instructions_1sm, math_instructions_2sm = generate_tf32_math_instructions_sm100(instantiation_level) + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + 
math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + if math_inst.instruction_shape[0] == 128: + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm + else: + epi_schedule = EpilogueScheduleType.ScheduleAuto + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) + +def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=490, default_level=490, exhaustive_level=9999) + + # layouts for ABC and their alignments. 
C alignment will be set later based on output type + layouts = [ + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]], + [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 0]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + math_instructions_1sm, math_instructions_2sm = generate_16b_math_instructions_sm100(instantiation_level) + + min_cc = 100 + max_cc = thor_sm + grouped = is_grouped(gemm_kind) + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : 
math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + kernel_schedule = KernelScheduleType.TmaWarpSpecialized1SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100 + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized1Sm if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[kernel_schedule, epi_schedule]], + tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_types_mixed = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. 
+ for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, + [[kernel_schedule, epi_schedule]], + tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. 
+ for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + if grouped: + epi_schedule = EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm + elif math_inst.instruction_shape[0] == 128: + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm + else: + epi_schedule = EpilogueScheduleType.ScheduleAuto + kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized2SmSm100, grouped) + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_types_mixed = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. 
+ for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, + [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + +def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=591 , default_level=591 , exhaustive_level=9999) + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + epi_type = DataType.f32 + grouped = is_grouped(gemm_kind) + + math_instructions_1sm, math_instructions_2sm = generate_fp8_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) + + tile_schedulers = [ + TileSchedulerType.Default, 
TileSchedulerType.StreamK + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : 
math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + } + ] + + # Set alignment d based on Destination format. 
+ for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type in data_types: + if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ + ( data_type["d_type"] == DataType.e5m2 ): + continue + kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized1SmSm100, grouped) + epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[kernel_schedule, epi_schedule]], + tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + + # 2xSM MMA kernels + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, 
+ }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + } + ] + + # Set alignment d based on Destination format. 
+ for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type in data_types: + if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ + ( data_type["d_type"] == DataType.e5m2 ): + continue + + if grouped: + epi_schedule = EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm + elif math_inst.instruction_shape[0] == 128: + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm + else: + epi_schedule = EpilogueScheduleType.ScheduleAuto + kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized2SmSm100, grouped) + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + +def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=593, default_level=593, exhaustive_level=9999) + + grouped = is_grouped(gemm_kind) + + # layouts for ABC and their alignments. 
+ layouts = [ + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + ] + + min_cc = 100 + max_cc = 100 + epi_type = DataType.f32 + + pruning_level = get_pruning_level_from_global_level(instantiation_level) + + math_instructions_1sm, math_instructions_2sm = generate_fp8_math_instructions_sm100(instantiation_level, enable_compile_time_dtype=grouped or pruning_level >= 1, enable_runtime_dtype=not grouped) + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) + + tile_schedulers = [ + TileSchedulerType.Default, + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape, + [math_inst.instruction_shape[0], math_inst.instruction_shape[1], + math_inst.instruction_shape[2] * 4])) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + 
math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape, + [1, math_inst.instruction_shape[1], + math_inst.instruction_shape[2] * 4])) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape, + [math_inst.instruction_shape[0], 1, + math_inst.instruction_shape[2] * 4])) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + ] + + # Set alignment d based on Destination format. 
+ for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + for data_type in data_types: + if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ + ( data_type["d_type"] == DataType.e5m2 ): + continue + + is_runtime_datatype_a = is_runtime_datatype(data_type["a_type"]) + is_runtime_datatype_b = is_runtime_datatype(data_type["b_type"]) + + # A/B datatypes should be both static or dynamic, so compare the A and B operand types + if (is_runtime_datatype_a != is_runtime_datatype_b): + continue + + kernel_schedule = to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100, grouped) + epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) + epi_schedule_nosmem = to_grouped_schedule(EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm, grouped) + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[kernel_schedule, epi_schedule], [kernel_schedule, epi_schedule_nosmem]], + tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + +def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + + # SM100 MMA with mixed F4/F6/F8 inputs + without block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=590, default_level=590, exhaustive_level=9999) + + grouped = is_grouped(gemm_kind) + + # layouts for ABC and their alignments. 
+ layouts = [ + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + math_instructions_1sm, math_instructions_2sm = generate_f8f6f4_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) + + def change_priority_func(shapes_1sm, shapes_2sm): + shapes_1sm[(1,2,1)] = 6 + shapes_1sm[(1,4,1)] = 6 + shapes_2sm[(2,2,1)] = 6 + shapes_2sm[(2,4,1)] = 6 + shapes_2sm[(4,2,1)] = 6 + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level, change_priority_func) + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + epi_type = DataType.f32 + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + kernel_data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + 
"a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + } + ] + + for kernel_data_type in kernel_data_types: + # Filter out some kernel + if ( kernel_data_type["a_type"] == DataType.e4m3 ) and ( kernel_data_type["b_type"] == DataType.e4m3 ) and\ + ( kernel_data_type["d_type"] == DataType.e5m2 ): + continue + + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], tile_schedulers=tile_schedulers) + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + kernel_data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : 
math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + } + ] + + for kernel_data_type in kernel_data_types: + # Filter some kernel + if ( kernel_data_type["a_type"] == DataType.e4m3 ) and ( kernel_data_type["b_type"] == DataType.e4m3 ) and\ + ( kernel_data_type["d_type"] == DataType.e5m2 ): + continue + + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + if math_inst.instruction_shape[0] == 128: + CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], tile_schedulers=tile_schedulers) + else: + CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]], tile_schedulers=tile_schedulers) + +def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): + + # SM100 MMA with mixed F4/F6/F8 inputs + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=590, default_level=590, exhaustive_level=9999) + + grouped = is_grouped(gemm_kind) + + layouts = [ + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 128], 
[LayoutType.ColumnMajor, 128], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 128], [LayoutType.RowMajor, 0]], + ] + + math_instructions_1sm, math_instructions_2sm = generate_mxf8f6f4_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) + + def change_priority_func(shapes_1sm, shapes_2sm): + shapes_1sm[(1,2,1)] = 6 + shapes_1sm[(1,4,1)] = 6 + shapes_2sm[(2,2,1)] = 6 + shapes_2sm[(2,4,1)] = 6 + shapes_2sm[(4,2,1)] = 6 + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level, change_priority_func) + + ab_types = [ + DataType.f4, DataType.f6, + DataType.e2m1, + DataType.e2m3, + DataType.e3m2, + DataType.e5m2, + DataType.e4m3, + ] + + acc_types = [ DataType.f32 ] + + def tile_schedulers(sfdtype): + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K. 
+ if sfdtype["type"] == DataType.void or grouped: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + epi_type = DataType.f32 + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + 
"epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e3m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type in data_types: + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]] + , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + + for math_inst in math_instructions_2sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : 
DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e3m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + ] + + # Set alignment d based on Destination format. 
+ for data_type in data_types: + for layout in layouts: + # alignment for a + layout[0][1] = get_tma_alignment_elt(data_type["a_type"]) + # alignment for b + layout[1][1] = get_tma_alignment_elt(data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(data_type["d_type"]) + for tile in tile_descriptions: + math_inst = tile.math_instruction + # Filter some kernels that does not meet the alignment requirements. + if layout[0][0] == LayoutType.ColumnMajor: + if math_inst.instruction_shape[0] // 2 % layout[0][1] != 0: + continue + else: + if tile.threadblock_shape[2] // tile.cluster_shape[2] % layout[0][1] != 0: + continue + + if layout[1][0] == LayoutType.RowMajor: + if math_inst.instruction_shape[1] // 2 % layout[1][1] != 0: + continue + else: + if tile.threadblock_shape[2] // tile.cluster_shape[2] % layout[1][1] != 0: + continue + + if grouped: + CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type], + [[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]] + , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + elif math_inst.instruction_shape[0] == 128: + CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type], + [[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]] + , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + else: + CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type], + [[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]] + , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + + + +def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): + # SM100 MMA with F4 + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + 
return + + instantiation_level = manifest.get_instantiation_level(pruned_level=591, default_level=591, exhaustive_level=9999) + + grouped = is_grouped(gemm_kind) + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.ColumnMajor, 0]], + ] + + math_instructions_1sm, math_instructions_2sm = generate_mxf4nvf4_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) + + def change_priority_func(shapes_1sm, shapes_2sm): + shapes_1sm[(1,2,1)] = 6 + shapes_1sm[(1,4,1)] = 6 + shapes_2sm[(2,2,1)] = 6 + shapes_2sm[(2,4,1)] = 6 + shapes_2sm[(4,2,1)] = 6 + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level, change_priority_func=change_priority_func) + + acc_types = [ DataType.f32 ] # Accumulator is always 32 bits for block scaled MMA instructions + + def tile_schedulers(sfdtype): + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K. 
+ if sfdtype["type"] == DataType.void or grouped: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + epi_type = DataType.f32 + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + assert math_inst.instruction_shape[2] * 4 == 256 + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, 
+ "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. 
+ for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for layout in layouts: + for data_type in data_types: + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + + # E2M1 x E2M1, vector size 32, E8 + # E2M1 x E2M1, vector size 16, UE4M3 + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) + epi_nosmem_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized1Sm, grouped) + nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped) + fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, grouped) + + nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]] + fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind + ) + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, fp4_schedules + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind + ) + + for math_inst in math_instructions_2sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == 
DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": 
None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for layout in layouts: + for data_type in data_types: + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. 
+ if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + + # E2M1 x E2M1, vector size 32, E8 + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + + epi_schedule = EpilogueScheduleType.ScheduleAuto if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm + epi_nosmem_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized2Sm, grouped) + nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped) + fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, grouped) + + nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]] + fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, fp4_schedules + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + +def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): + # SM100 MMA with F4 + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 13, 0): + return + + grouped = is_grouped(gemm_kind) + + # layouts for ABC and their alignments. 
+ layouts = [ + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.ColumnMajor, 0]], + ] + + instruction_sizes_1sm = [ + [128, 128, 96], + ] + + instruction_sizes_2sm = [ + [256, 128, 96], + [256, 192, 96], + [256, 256, 96] + ] + + ab_types = [ + DataType.f4, + DataType.e2m1, + ] + + sf_types = [ + DataType.ue4m3, + DataType.ue8m0 + ] + + acc_types = [ DataType.f32 ] # Accumulator is always 32 bits for block scaled MMA instructions + + def tile_schedulers(sfdtype): + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K. + if grouped: + return [TileSchedulerType.Default] + if sfdtype["type"] == DataType.void: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + min_cc = 103 + max_cc = 103 + epi_type = DataType.f32 + + math_instructions_1sm = [] + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + + for instr_size, a_type, b_type, sf_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, sf_types, acc_types): + is_runtime_datatype_a = is_runtime_datatype(a_type) + is_runtime_datatype_b = is_runtime_datatype(b_type) + + # A/B datatypes should be both static or dynamic + if (is_runtime_datatype_a != is_runtime_datatype_b): + continue + + math_instructions_1sm.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + sf_type) + ) + + math_instructions_2sm = [] + + for instr_size, a_type, b_type, sf_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, sf_types, acc_types): + is_runtime_datatype_a = is_runtime_datatype(a_type) + is_runtime_datatype_b = is_runtime_datatype(b_type) + + # A/B datatypes should 
be both static or dynamic + if (is_runtime_datatype_a != is_runtime_datatype_b): + continue + + math_instructions_2sm.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + sf_type) + ) + + cluster_shapes_1sm = [ + [1,1,1], + # [1,2,1], + [2,1,1], + # [1,4,1], + [4,4,1], + DynamicClusterShape + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + 768], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : 
math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. 
+ for layout in layouts: + for data_type in data_types: + # Set alignment d based on Destination format. + if DataTypeSize[data_type["c_type"]] == 0 : + layout[2][1] = 256 // DataTypeSize[data_type["d_type"]] + else: + layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]]) + + if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + # E2M1 x E2M1, vector size 32, E8 + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + + epilogue_1sm_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized1Sm, grouped) + + nvfp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, grouped), epilogue_1sm_schedule] + nvfp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, grouped), epilogue_1sm_schedule] + nvfp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, grouped), epilogue_1sm_schedule] + fp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, grouped), epilogue_1sm_schedule] + fp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, grouped), epilogue_1sm_schedule] + fp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, grouped), epilogue_1sm_schedule] + nvfp4_schedules = [nvfp4_schedule, 
nvfp4_schedule_disable_prefetch, nvfp4_schedule_tma_prefetch] + fp4_schedules = [fp4_schedule, fp4_schedule_disable_prefetch, fp4_schedule_tma_prefetch] + + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + nvfp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + fp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + + cluster_shapes_2sm = [ + [2,1,1], + # [2,2,1], + # [2,4,1], + [4,1,1], + # [4,2,1], + [4,4,1], + DynamicClusterShape + ] + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 8 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : 
math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : 
math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + for data_type in data_types: + # Set alignment d based on Destination format. + if DataTypeSize[data_type["c_type"]] == 0 : + layout[2][1] = 256 // DataTypeSize[data_type["d_type"]] + else: + layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]]) + + if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + # E2M1 x E2M1, vector size 32, E8 + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + + epilogue_2sm_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized2Sm, grouped) + + nvfp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, grouped), epilogue_2sm_schedule] + nvfp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, grouped), epilogue_2sm_schedule] + nvfp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, grouped), epilogue_2sm_schedule] + fp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, grouped), epilogue_2sm_schedule] + fp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, 
grouped), epilogue_2sm_schedule] + fp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, grouped), epilogue_2sm_schedule] + nvfp4_schedules = [nvfp4_schedule, nvfp4_schedule_disable_prefetch, nvfp4_schedule_tma_prefetch] + fp4_schedules = [fp4_schedule, fp4_schedule_disable_prefetch, fp4_schedule_tma_prefetch] + + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + nvfp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + fp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + + +def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + epi_type = DataType.f32 + + math_instructions_1sm = [ + MathInstruction( + [64, 128, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 128, 32], + DataType.s8, DataType.s8, 
DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add)] + + cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1], [4,4,1] + , DynamicClusterShape + ] + + if thor_sm in manifest.compute_capabilities_baseline : + cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1] + , DynamicClusterShape + ] + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. 
+ for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_types_mixed = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + ] + # Set alignment d based on Destination format. 
+ for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + math_instructions_2sm = [ + MathInstruction( + [128, 128, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 128, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1] + , DynamicClusterShape + ] + + if thor_sm in manifest.compute_capabilities_baseline : + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1] + , DynamicClusterShape + ] + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : 
math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + if math_inst.instruction_shape[0] == 128: + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm + else: + epi_schedule = EpilogueScheduleType.ScheduleAuto + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_types_mixed = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) + + +def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. 
+ layouts = [ + # Alignment requirement will be over-write below + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + kernel_data_types = [ + # void_c + { + "a_type" : DataType.f32, + "b_type" : DataType.f32, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + # none void_c + { + "a_type" : DataType.f32, + "b_type" : DataType.f32, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + ] + + math_instructions_1sm = [ + MathInstruction( + [128, 128, 16], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 16], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + math_instructions_2sm = [ + MathInstruction( + [256, 128, 16], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 16], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type 
in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_2sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, 
layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], + tile_schedulers=tile_schedulers) + +def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. + layouts = [ + # Alignment requirement will be over-write below + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + kernel_data_types = [ + # void_c + { + "a_type" : DataType.f16, + "b_type" : DataType.f16, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + # none void_c + { + "a_type" : DataType.f16, + "b_type" : DataType.f16, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + ] + + math_instructions_1sm = [ + MathInstruction( + [128, 128, 32], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 32], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + math_instructions_2sm = [ + MathInstruction( + [256, 128, 32], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 32], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm 
= (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_2sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for 
a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], + tile_schedulers=tile_schedulers) + +def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. + layouts = [ + # Alignment requirement will be over-write below + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + kernel_data_types = [ + # void_c + { + "a_type" : DataType.s8, + "b_type" : DataType.s8, + "c_type" : DataType.void, + "d_type" : DataType.s8, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + # none void_c + { + "a_type" : DataType.s8, + "b_type" : DataType.s8, + "c_type" : DataType.s8, + "d_type" : DataType.s8, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + ] + + math_instructions_1sm = [ + MathInstruction( + [128, 128, 64], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 64], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add)] + + math_instructions_2sm = [ + MathInstruction( + [256, 128, 64], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 64], + DataType.s8, 
DataType.s8, DataType.s32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_2sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * 
multiplier_2sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], + tile_schedulers=tile_schedulers) + +def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. + layouts = [ + # Alignment requirement will be over-write below + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + kernel_data_types = [ + # NOTE: a/b type in kernel will be overwrite below. 
+ #* void_c + # f8_f8_f32_void_f16 + { + "a_type" : DataType.e4m3, + "b_type" : DataType.e4m3, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + #* non-void_c + # f8_f8_f32_f16_f8 + { + "a_type" : DataType.e4m3, + "b_type" : DataType.e4m3, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + ] + + math_instructions_1sm = [ + # Runtime DType + MathInstruction( + [128, 128, 64], + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 64], + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + math_instructions_2sm = [ + # Runtime DType + MathInstruction( + [256, 128, 64], + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 64], + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update input AB type + kernel_data_type["a_type"] = math_inst.element_a + kernel_data_type["b_type"] = math_inst.element_b + + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = 
copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_2sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update input AB type + kernel_data_type["a_type"] = math_inst.element_a + kernel_data_type["b_type"] = math_inst.element_b + + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, 
layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], + tile_schedulers=tile_schedulers) + +def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. + layouts = [ + # Alignment requirement will be over-write below + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + math_instructions_1sm = [ + # Runtime Dtype + MathInstruction( + [128, 128, 64], + DataType.f4, DataType.f4, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 64], + DataType.f4, DataType.f4, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + + MathInstruction( + [128, 128, 64], + DataType.f6, DataType.f6, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 64], + DataType.f6, DataType.f6, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + math_instructions_2sm = [ + # Runtime DType + MathInstruction( + [256, 128, 64], + DataType.f4, DataType.f4, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 64], + DataType.f4, DataType.f4, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + + MathInstruction( + [256, 128, 64], + DataType.f6, DataType.f6, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 64], + DataType.f6, DataType.f6, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + # 1xSM MMA kernels + for math_inst in 
math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + kernel_data_types = [ + # void_c + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + }, + # none void_c + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + }, + ] + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_filtered = [] + for layout in layouts: + layout_filter = copy.deepcopy(layout) + # * A_K : Logical TileShape_K % 256 == 0 + # * A_M : TileShape_M % 128 == 0 + # * B_N : TileSize_N % 128 == 0 + # * B_K : TileSize_K % 128 == 0 + if ((layout_filter[0][0] == LayoutType.RowMajor and (math_inst.instruction_shape[2] * 2) % 256 == 0) or \ + (layout_filter[0][0] == LayoutType.ColumnMajor and math_inst.instruction_shape[0] % 128 == 0)) and \ + ((layout_filter[1][0] == LayoutType.RowMajor and math_inst.instruction_shape[1] % 128 == 0) or \ + (layout_filter[1][0] == LayoutType.ColumnMajor and (math_inst.instruction_shape[0] * 2) % 128 == 0)): + # alignment for a, 2 for sparsity + layout_filter[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b 
+ layout_filter[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout_filter[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + layouts_filtered.append(layout_filter) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_filtered, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_2sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + kernel_data_types = [ + # void_c + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + }, + # none void_c + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + }, + ] + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_filtered = [] + for layout in layouts: + layout_filter = copy.deepcopy(layout) + # * A_K : Logical TileShape_K % 256 == 0 + # * A_M : TileShape_M % 128 == 0 + # * B_N : TileSize_N % 256 == 0 + # * B_K : TileSize_K % 128 == 0 + if ((layout_filter[0][0] == LayoutType.RowMajor and 
(math_inst.instruction_shape[2] * 2) % 256 == 0) or \ + (layout_filter[0][0] == LayoutType.ColumnMajor and math_inst.instruction_shape[0] % 128 == 0)) and \ + ((layout_filter[1][0] == LayoutType.RowMajor and math_inst.instruction_shape[1] % 256 == 0) or \ + (layout_filter[1][0] == LayoutType.ColumnMajor and (math_inst.instruction_shape[0] * 2) % 128 == 0)): + # alignment for a, 2 for sparsity + layout_filter[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout_filter[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout_filter[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + layouts_filtered.append(layout_filter) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_filtered, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], + tile_schedulers=tile_schedulers) + +# Conv Utility functions +def make_dims_and_alignments_triple(dim: int, bit_per_element_A: int, bit_per_element_B: int, bit_per_element_C: int): + bit_alignment_required_by_tma = 128 + return ((dim, bit_alignment_required_by_tma // bit_per_element_A), # A + (dim, bit_alignment_required_by_tma // bit_per_element_B), # B + (dim, bit_alignment_required_by_tma // bit_per_element_C)) # C + +def make_math_instruction_w_output(data_types: Tuple[DataType, DataType, DataType, DataType], + instruction_shape: Tuple[int, int, int]) -> (MathInstruction, DataType): + default_opcode = OpcodeClass.TensorOp + default_math_op = MathOperation.multiply_add + [A_data_type, B_data_type, Acc_data_type, Out_data_type] = data_types + return (MathInstruction( + instruction_shape, + A_data_type, B_data_type, Acc_data_type, + default_opcode, + default_math_op + ), Out_data_type) + +""" +Generate CUTLASS 3 convolution kernel(s) for SM100. + +This is meant to be called from GenerateSM100. 
+""" +def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version, + log_indent_level: int = 0): + log_debug_line('GenerateSM100_TensorOp_16b_UMMA_conv3x', log_indent_level) + log_indent_level = log_indent_level + 1 + + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + thor_sm = ThorSMRenumbering(cuda_version) + + minimum_compute_capability = 100 + maximum_compute_capability = thor_sm + + spatial_dims = [2, 3] + + conv_kinds = [ + ConvKind.Fprop, + ConvKind.Dgrad, + ConvKind.Wgrad + ] + + stages = 0 # zero means "deduce the number of stages automatically" + + data_types_and_instruction_shapes_1sm = [ + # ((A,B,Acc,C/D), (InstM,InstN,InstK)) + ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (64, 128, 16)), + ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (128, 128, 16)), + ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (128, 256, 16)), + ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (64, 128, 16)), + ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (128, 128, 16)), + ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (128, 256, 16)), + ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (64, 128, 16)), + ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (128, 128, 16)), + ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (128, 256, 16)), + ] + math_instructions_w_output_1sm = map(lambda x: make_math_instruction_w_output(*x), + data_types_and_instruction_shapes_1sm) + + cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]] + + if thor_sm in manifest.compute_capabilities_baseline : + cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]] + + # tile_descriptions is a 2-level list. + # Each inner list is for each cluster shape. 
+ for math_inst, output_type in math_instructions_w_output_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + cluster_multiplier = cluster_shape + # Unlike SM90, SM100 tile shape calculation includes cluster shape. + tile_shape = [ + math_inst.instruction_shape[0] * cluster_multiplier[0], + math_inst.instruction_shape[1] * cluster_multiplier[1], + math_inst.instruction_shape[2] * 4 * cluster_multiplier[2] + ] + warp_count = [4, 1, 1] + tile_description = TileDescription( + tile_shape, stages, warp_count, math_inst, + minimum_compute_capability, maximum_compute_capability, + cluster_shape) + tile_descriptions.append(tile_description) + + # It's typical to get the data types from the math instruction. + data_type = { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : output_type, + "d_type" : output_type, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + } + + dims_and_alignments = [make_dims_and_alignments_triple(dim, DataTypeSize[data_type["a_type"]], DataTypeSize[data_type["b_type"]], DataTypeSize[data_type["d_type"]]) for dim in spatial_dims] + + # Schedules + mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100 + epilogue_schedule = EpilogueScheduleType.ScheduleAuto + schedule_pairs = [ + (mainloop_schedule, epilogue_schedule) + ] + + for conv_kind in conv_kinds: + CreateConvOperator3x(manifest, + dims_and_alignments = dims_and_alignments, + tile_descriptions = tile_descriptions, + data_types = data_type, + schedule_pairs = schedule_pairs, + conv_kind = conv_kind, + log_indent_level = log_indent_level) + + data_types_and_instruction_shapes_2sm = [ + # ((A,B,Acc,C/D), (InstM,InstN,InstK)) + ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (128, 128, 16)), + ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (128, 256, 16)), + ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (256, 256, 16)), + ((DataType.f16, 
DataType.f16, DataType.f32, DataType.f16), (128, 128, 16)), + ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (128, 256, 16)), + ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (256, 256, 16)), + ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (128, 128, 16)), + ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (128, 256, 16)), + ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (256, 256, 16)), + ] + math_instructions_w_output_2sm = map(lambda x: make_math_instruction_w_output(*x), + data_types_and_instruction_shapes_2sm) + + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]] + if thor_sm in manifest.compute_capabilities_baseline : + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]] + + for math_inst, output_type in math_instructions_w_output_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + cluster_multiplier = (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + # Unlike SM90, SM100 tile shape calculation includes cluster shape. + tile_shape = [ + math_inst.instruction_shape[0] * cluster_multiplier[0], + math_inst.instruction_shape[1] * cluster_multiplier[1], + math_inst.instruction_shape[2] * 4 * cluster_multiplier[2] + ] + warp_count = [4, 1, 1] + tile_description = TileDescription( + tile_shape, stages, warp_count, math_inst, + minimum_compute_capability, maximum_compute_capability, + cluster_shape) + tile_descriptions.append(tile_description) + + # It's typical to get the data types from the math instruction. 
+ data_type = { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : output_type, + "d_type" : output_type, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + } + + dims_and_alignments = [make_dims_and_alignments_triple(dim, DataTypeSize[data_type["a_type"]], DataTypeSize[data_type["b_type"]], DataTypeSize[data_type["d_type"]]) for dim in spatial_dims] + + # Schedules + mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100 + epilogue_schedule = EpilogueScheduleType.ScheduleAuto + schedule_pairs = [ + (mainloop_schedule, epilogue_schedule) + ] + + for conv_kind in conv_kinds: + CreateConvOperator3x(manifest, + dims_and_alignments = dims_and_alignments, + tile_descriptions = tile_descriptions, + data_types = data_type, + schedule_pairs = schedule_pairs, + conv_kind = conv_kind, + log_indent_level = log_indent_level) + +def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version, + log_indent_level: int = 0): + # Instantiate Fp8 Fprop kernels with e4m3 A/B, f32 Acc, e4m3/bf16/f16/f32 C/D + log_debug_line('GenerateSM100_TensorOp_fp8_UMMA_conv3x', log_indent_level) + log_indent_level = log_indent_level + 1 + + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + thor_sm = ThorSMRenumbering(cuda_version) + + minimum_compute_capability = 100 + maximum_compute_capability = thor_sm + + spatial_dims = [2, 3] + stages = 0 # zero means "deduce the number of stages automatically" + + data_types_and_instruction_shapes_1sm = [ + # ((A,B,Acc,C/D), (InstM,InstN,InstK)) + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (64, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (64, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (128, 128, 32)), + 
((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (64, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (64, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (128, 256, 32)), + ] + math_instructions_w_output_1sm = map(lambda x: make_math_instruction_w_output(*x), + data_types_and_instruction_shapes_1sm) + + cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]] + if thor_sm in manifest.compute_capabilities_baseline : + cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]] + + for math_inst, output_type in math_instructions_w_output_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + cluster_multiplier = cluster_shape + # Unlike SM90, SM100 tile shape calculation includes cluster shape. 
+ tile_shape = [ + math_inst.instruction_shape[0] * cluster_multiplier[0], + math_inst.instruction_shape[1] * cluster_multiplier[1], + math_inst.instruction_shape[2] * 4 * cluster_multiplier[2] + ] + warp_count = [4, 1, 1] + tile_description = TileDescription( + tile_shape, stages, warp_count, math_inst, + minimum_compute_capability, maximum_compute_capability, + cluster_shape) + tile_descriptions.append(tile_description) + + data_type = { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : output_type, + "d_type" : output_type, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + } + + dims_and_alignments = [make_dims_and_alignments_triple(dim, DataTypeSize[data_type["a_type"]], DataTypeSize[data_type["b_type"]], DataTypeSize[data_type["d_type"]]) for dim in spatial_dims] + + # Schedules + mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100 + epilogue_schedule = EpilogueScheduleType.ScheduleAuto + schedule_pairs = [ + (mainloop_schedule, epilogue_schedule) + ] + + CreateConvOperator3x(manifest, + dims_and_alignments = dims_and_alignments, + tile_descriptions = tile_descriptions, + data_types = data_type, + schedule_pairs = schedule_pairs, + conv_kind = ConvKind.Fprop, + log_indent_level = log_indent_level) + + data_types_and_instruction_shapes_2sm = [ + # ((A,B,Acc,C/D), (InstM,InstN,InstK)) + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (256, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (256, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (128, 128, 32)), + ((DataType.e4m3, 
DataType.e4m3, DataType.f32, DataType.bf16), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (256, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (256, 256, 32)), + ] + math_instructions_w_output_2sm = map(lambda x: make_math_instruction_w_output(*x), + data_types_and_instruction_shapes_2sm) + + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]] + if thor_sm in manifest.compute_capabilities_baseline : + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]] + + for math_inst, output_type in math_instructions_w_output_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + cluster_multiplier = (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + # Unlike SM90, SM100 tile shape calculation includes cluster shape. 
+ tile_shape = [ + math_inst.instruction_shape[0] * cluster_multiplier[0], + math_inst.instruction_shape[1] * cluster_multiplier[1], + math_inst.instruction_shape[2] * 4 * cluster_multiplier[2] + ] + warp_count = [4, 1, 1] + tile_description = TileDescription( + tile_shape, stages, warp_count, math_inst, + minimum_compute_capability, maximum_compute_capability, + cluster_shape) + tile_descriptions.append(tile_description) + + data_type = { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : output_type, + "d_type" : output_type, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + } + + dims_and_alignments = [make_dims_and_alignments_triple(dim, DataTypeSize[data_type["a_type"]], DataTypeSize[data_type["b_type"]], DataTypeSize[data_type["d_type"]]) for dim in spatial_dims] + + # Schedules + mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100 + epilogue_schedule = EpilogueScheduleType.ScheduleAuto + schedule_pairs = [ + (mainloop_schedule, epilogue_schedule) + ] + + CreateConvOperator3x(manifest, + dims_and_alignments = dims_and_alignments, + tile_descriptions = tile_descriptions, + data_types = data_type, + schedule_pairs = schedule_pairs, + conv_kind = ConvKind.Fprop, + log_indent_level = log_indent_level) + +def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version): + # SM120 MMA with mixed F4/F6/F8 inputs + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + layouts = [ + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]] + ] + + instruction_sizes = [ + [16, 8, 32] + ] + + tile_sizes = [ + [128, 128, 128] + ] + + cluster_shape = [1,1,1] + + ab_types = [ + DataType.e2m1, + DataType.e2m3, + DataType.e3m2, + DataType.e5m2, + DataType.e4m3, + ] + + acc_types = [ DataType.f32 ] + + def is_pingpong(kernel_schedule): + if kernel_schedule == 
KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: + return True + else: + return False + + def tile_schedulers(sfdtype, kernel_schedule): + # Pingpong kernel schedule doesn't support stream-K. + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K + if is_pingpong(kernel_schedule): + return [TileSchedulerType.Default] + elif sfdtype["type"] == DataType.void: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + min_cc = 120 + max_cc = 121 + + epi_type = DataType.f32 + + math_instructions = [] + + kernel_schedules = [ + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120, + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120 + ] + + for instr_size, a_type, b_type, acc_type in product(instruction_sizes, ab_types, ab_types, acc_types): + math_instructions.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + for math_inst in math_instructions: + tile_descriptions = [] + for tile_size in tile_sizes: + tile_descriptions.append( + TileDescription(tile_size, 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": 
def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version):
    """
    Register SM120 block-scaled FP4 GEMM kernels in ``manifest``.

    Returns without generating anything unless the CUDA toolkit is at least
    12.8. A and B are both e2m1 with f32 accumulation. Math instructions are
    built for both ue4m3 and ue8m0 scale factors, but tile descriptions are
    only produced when the scale-factor type matches the kernel schedule
    (ue4m3 for the Nvf4 schedules, ue8m0 for the Mxf4 schedules).

    :param manifest: manifest forwarded to ``CreateGemmUniversal3xOperator``
    :param cuda_version: CUDA toolkit version checked by ``CudaToolkitVersionSatisfies``
    """
    # SM120 MMA with F4 + block scale
    if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
        return

    min_cc = 120
    max_cc = 121
    cluster_shape = [1,1,1]
    epi_type = DataType.f32

    # Layouts for A, B and C with their alignments. The C alignment (0 here)
    # is a placeholder that is overwritten before any operator is created.
    layouts = [
        [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]]
    ]

    instruction_sizes = [[16, 8, 64]]
    ab_types = [DataType.e2m1]
    sf_types = [DataType.ue4m3, DataType.ue8m0]
    acc_types = [DataType.f32]

    # The pingpong schedules get one tile shape fewer than the cooperative ones.
    tile_sizes_cooperative = [
        [128, 128, 128],
        [128, 128, 256],
        [256, 128, 128]
    ]
    tile_sizes_pingpong = [
        [128, 128, 128],
        [128, 128, 256]
    ]

    kernel_schedules = [
        KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120,
        KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120,
        KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120,
        KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120
    ]

    def is_pingpong(kernel_schedule):
        return kernel_schedule in (
            KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120,
            KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120,
        )

    def is_nvf4(kernel_schedule):
        return kernel_schedule in (
            KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120,
            KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120,
        )

    def tile_schedulers(sfdtype, kernel_schedule):
        # Pingpong kernel schedule doesn't support stream-K.
        # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void,
        # the epilogue is the traditional linear combination, for which we already have tests with stream-K
        if is_pingpong(kernel_schedule) or sfdtype["type"] == DataType.void:
            return [TileSchedulerType.Default]
        return [TileSchedulerType.Default, TileSchedulerType.StreamK]

    math_instructions = [
        MathInstruction(
            instr_size,
            a_type, b_type, acc_type,
            OpcodeClass.BlockScaledTensorOp,
            MathOperation.multiply_add,
            sf_type)
        for instr_size, a_type, b_type, acc_type, sf_type
        in product(instruction_sizes, ab_types, ab_types, acc_types, sf_types)
    ]

    # One row per generated data-type combination:
    # (c_type, d_type, sfd type, sfd vector size, sfd layout)
    output_variants = [
        (DataType.void, DataType.f32,  DataType.void,  None, None),
        (DataType.void, DataType.e2m1, DataType.ue8m0, 32,   LayoutType.RowMajor),
        (DataType.void, DataType.e5m2, DataType.void,  None, None),
        (DataType.f16,  DataType.e5m2, DataType.void,  None, None),
        (DataType.void, DataType.e2m1, DataType.ue8m0, 16,   LayoutType.RowMajor),
        (DataType.f16,  DataType.e2m1, DataType.ue8m0, 16,   LayoutType.RowMajor),
        (DataType.f16,  DataType.e2m1, DataType.ue8m0, 32,   LayoutType.RowMajor),
    ]

    for math_inst in math_instructions:
        for kernel_schedule in kernel_schedules:
            candidate_tiles = tile_sizes_pingpong if is_pingpong(kernel_schedule) else tile_sizes_cooperative

            # nvf4 kernel only supports ue4m3 SF
            # mxf4 kernel only supports ue8m0 SF
            required_sf = DataType.ue4m3 if is_nvf4(kernel_schedule) else DataType.ue8m0
            tile_descriptions = []
            if math_inst.element_scale_factor == required_sf:
                for tile_size in candidate_tiles:
                    tile_descriptions.append(
                        TileDescription(tile_size, 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape))

            # Fresh dicts are built for every (instruction, schedule) pair.
            data_types = [
                {
                    "a_type"   : math_inst.element_a,
                    "b_type"   : math_inst.element_b,
                    "c_type"   : c_type,
                    "d_type"   : d_type,
                    "acc_type" : math_inst.element_accumulator,
                    "epi_type" : epi_type,
                    "sf_type"  : math_inst.element_scale_factor,
                    "sfd_type" : {"type": sfd_type, "vector_size": sfd_vector_size, "layout" : sfd_layout}
                }
                for c_type, d_type, sfd_type, sfd_vector_size, sfd_layout in output_variants
            ]

            # Set alignment d based on Destination format. Only the first
            # data-type entry is consulted, so every operator below shares it.
            d_alignment = 128 // DataTypeSize[data_types[0]["d_type"]]
            for layout in layouts:
                layout[2][1] = d_alignment

            for data_type in data_types:
                CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
                    [[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
                    tile_schedulers = tile_schedulers(data_type["sfd_type"], kernel_schedule),
                    gemm_kind = GemmKind.BlockScaledUniversal3x
                )
def GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x):
    """
    Register SM120 FP8 blockwise-scaled GEMM kernels in ``manifest``.

    Returns without generating anything unless the CUDA toolkit is at least
    12.8. For each kernel schedule (cooperative, pingpong) and each A/B type
    pair (e4m3 x e4m3 and e4m3 x e5m2, f32 accumulation), one tile description
    is built per admissible scale-factor granularity, and an operator is
    created for every C/D type combination.

    :param manifest: manifest forwarded to ``CreateGemmUniversal3xOperator``
    :param cuda_version: CUDA toolkit version checked by ``CudaToolkitVersionSatisfies``
    :param gemm_kind: GEMM kind forwarded to ``CreateGemmUniversal3xOperator``
    """
    if not CudaToolkitVersionSatisfies(cuda_version, 12, 8):
        return

    # Layouts for A, B and C/D with their alignments. The C/D alignment is
    # rewritten per data type right before each operator is created.
    layouts = [
        [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 16]],
        [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.ColumnMajor, 16]]
    ]

    cooperative_tile_sizes = [
        [128, 128, 128]
    ]
    pingpong_tile_sizes = [
        [64, 128, 128]
    ]

    def get_tile_sizes(kernel_scheduler):
        if kernel_scheduler == KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120:
            return pingpong_tile_sizes
        return cooperative_tile_sizes

    def get_warp_count(kernel_scheduler):
        if kernel_scheduler == KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120:
            return [2, 2, 1]
        return [4, 2, 1]

    def get_sf_sizes(tile_size):
        # Collect every [vec_m, vec_n, 128] granularity whose M and N
        # components evenly divide the tile's M and N extents.
        sf_sizes = []
        for vec_m in [1, 128]:
            if tile_size[0] % vec_m > 0:
                continue
            for vec_n in [1, 128]:
                # Check the N extent against vec_n. The previous code tested
                # vec_m here, so the N granularity was never validated.
                if tile_size[1] % vec_n > 0:
                    continue
                sf_sizes.append([vec_m, vec_n, 128])
        return sf_sizes

    cluster_shape = [1,1,1]

    acc_types = [ DataType.f32 ]

    instruction_sizes = [
        [16, 8, 32]
    ]

    def tile_schedulers(kernel_schedule):
        return [TileSchedulerType.Default]

    min_cc = 120
    max_cc = 121

    kernel_schedulers = [
        KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120,
        KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120
    ]

    ab_types = [
        [DataType.e4m3, DataType.e4m3],
        [DataType.e4m3, DataType.e5m2]
    ]

    math_instructions = [
        MathInstruction(
            instr_size,
            a_type, b_type, acc_type,
            OpcodeClass.TensorOp,
            MathOperation.multiply_add)
        for instr_size, (a_type, b_type), acc_type
        in product(instruction_sizes, ab_types, acc_types)
    ]

    # (c_type, d_type) combinations generated for every math instruction.
    cd_types = [
        (DataType.f16,  DataType.f16),
        (DataType.bf16, DataType.bf16),
        (DataType.void, DataType.f16),
        (DataType.void, DataType.bf16),
    ]

    # Create the FP8 blockwise gemm operators
    for kernel_schedule in kernel_schedulers:
        tile_sizes = get_tile_sizes(kernel_schedule)
        warp_count = get_warp_count(kernel_schedule)
        for math_inst in math_instructions:
            tile_descriptions = []
            for tile_size in tile_sizes:
                for sf_size in get_sf_sizes(tile_size):
                    tile_descriptions.append(
                        TileDescription(tile_size, 0, warp_count, math_inst, min_cc, max_cc, cluster_shape,
                                        explicit_vector_sizes=sf_size)
                    )

            data_types = [
                {
                    "a_type"   : math_inst.element_a,
                    "b_type"   : math_inst.element_b,
                    "c_type"   : c_type,
                    "d_type"   : d_type,
                    "acc_type" : math_inst.element_accumulator,
                    "epi_type" : DataType.f32
                }
                for c_type, d_type in cd_types
            ]

            for data_type in data_types:
                # Set alignment d based on Destination format
                for layout in layouts:
                    layout[2][1] = int(128 // DataTypeSize[data_type["d_type"]])
                # Create gemm operator
                CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
                                              [[kernel_schedule, EpilogueScheduleType.ScheduleAuto]],
                                              tile_schedulers = tile_schedulers(kernel_schedule),
                                              gemm_kind = gemm_kind)
def GenerateSM120(manifest, cuda_version):
    """
    Run every SM120 kernel generator against ``manifest``.

    The generators are invoked in a fixed order: dense block-scaled GEMM
    (mixed 8-bit, then FP4), sparse GEMM, then FP8 blockwise GEMM in its
    default kind followed by the grouped blockwise kind.

    :param manifest: manifest forwarded to each generator
    :param cuda_version: CUDA toolkit version forwarded to each generator
    """
    # StreamK is included in regular generation
    generators_in_order = (
        # Dense Block Scaled Gemm
        GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled,
        GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled,
        # Sparse Gemm
        GenerateSM120_Sparse_TensorOp_gemm,
        # Blockwise Gemm (default gemm kind)
        GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise,
    )
    for generate in generators_in_order:
        generate(manifest, cuda_version)

    # Grouped Blockwise Gemm
    GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(
        manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockwiseUniversal3x)
+ """ + Generate CUTLASS 3 convolution kernel(s) for SM90. + + This is meant to be called from GenerateSM90. + """ + log_debug_line('GenerateSM90_Conv3x', log_indent_level) + log_indent_level = log_indent_level + 1 + + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + minimum_compute_capability = 90 + maximum_compute_capability = 90 + + spatial_dims = (2, 3) + + # MMA shapes (MMA_M, MMA_N, MMA_K): + # + # Different hardware MMA instructions may have different MMA shapes. + # This function may generate kernels with different MMA shapes for + # different data types, either because the hardware only supports + # certain shapes for certain types, or for performance reasons + # (CUTLASS doesn't need to generate all valid kernels for the + # profiler library, just the best-performing ones). + # + # The kernel names refer to tile shapes (TILE_M, TILE_N, TILE_K) + # instead of MMA shapes. For SM >= 90 kernels, TILE_K = 4 * MMA_K, + # where 4, the "number of MMA instructions per tile," is determined + # through some combination of modeling and experiment. + # + # For performance on sm90, generally CUTLASS generates 64x128 + # instead of 128x64. + mma_64x64x16 = ( 64, 64, 16) + mma_64x64x8 = ( 64, 64, 8) + + num_mma_per_tile = 4 + + # Cluster shapes (1, 1, 1) and (2, 2, 1) are valid, + # but not included, because they tend not to perform as well. + cluster_shapes = ( + (2, 1, 1), + (1, 2, 1), + ) + + fp16 = DataType.f16 + bf16 = DataType.bf16 + fp32 = DataType.f32 + s8 = DataType.s8 + s32 = DataType.s32 + + # When generating kernels, the usual way is to specify 4 types, + # (A, B, Acc, C/D). Tests instead have 5 types, + # (ElementAct, ElementFlt, ElementOut, ElementAcc, ElementCompute), + # where ElementCompute is also called 'epi_type', + # and corresponds to the type of epilogue activations. + # This script maps tests' 5 types to 4 types + # by making ElementCompute the same as ElementOut. 
+ + fp16_fp32_fp16_fp32 = { + 'a_type': fp16, # ElementAct(ivation) + 'b_type': fp16, # ElementF(i)lt(er) + 'c_type': fp32, # ElementAcc + 'd_type': fp32, # ElementOut (used only by CollectiveEpilogue) + 'acc_type': fp16, # ElementAcc + 'epi_type': fp32, # ElementCompute (used only by CollectiveEpilogue) + 'alignment_A': 8, # tma alignment elements of A + 'alignment_B': 8, # tma alignment elements of B + 'alignment_C': 4, # tma alignment elements of C + } + fp16_fp32_fp32_fp32 = { + 'a_type': fp16, + 'b_type': fp16, + 'c_type': fp32, + 'd_type': fp32, + 'acc_type': fp32, + 'epi_type': fp32, + 'alignment_A': 8, + 'alignment_B': 8, + 'alignment_C': 4, + } + fp32_fp32_fp32_fp32 = { + 'a_type': fp32, + 'b_type': fp32, + 'c_type': fp32, + 'd_type': fp32, + 'acc_type': fp32, + 'epi_type': fp32, + 'alignment_A': 4, + 'alignment_B': 4, + 'alignment_C': 4, + } + s8_s32_s32_s32 = { + 'a_type': s8, + 'b_type': s8, + 'c_type': s32, + 'd_type': s32, + 'acc_type': s32, + 'epi_type': s32, + 'alignment_A': 16, + 'alignment_B': 16, + 'alignment_C': 4, + } + + # Other NVIDIA libraries may have the habit of specifying data types like this. 
+ bf16bf16_bf16f32_f32 = { + 'a_type': bf16, + 'b_type': bf16, + 'c_type': fp32, + 'd_type': fp32, + 'acc_type': fp32, + 'epi_type': fp32, + 'alignment_A': 8, + 'alignment_B': 8, + 'alignment_C': 4, + } + f16f16_f16f16_f16 = { + 'a_type': fp16, + 'b_type': fp16, + 'c_type': fp16, + 'd_type': fp16, + 'acc_type': fp16, + 'epi_type': fp16, + 'alignment_A': 8, + 'alignment_B': 8, + 'alignment_C': 8, + } + f16f16_f16f32_f32 = { + 'a_type': fp16, + 'b_type': fp16, + 'c_type': fp16, + 'd_type': fp16, + 'acc_type': fp32, + 'epi_type': fp32, + 'alignment_A': 8, + 'alignment_B': 8, + 'alignment_C': 8, + } + f32f32_tf32f32_f32 = fp32_fp32_fp32_fp32 + + i8i8_i8i32_f32 = { + 'a_type': s8, + 'b_type': s8, + 'c_type': s32, + 'd_type': s32, + 'acc_type': s32, + 'epi_type': s32, + 'alignment_A': 16, + 'alignment_B': 16, + 'alignment_C': 4, + } + + # Each element in the outermost iterable is one combination of + # + # (ConvKind, spatial_dimension, data_types, byte_alignments, mma_sizes, cluster_sizes) + # + # for which to generate a kernel. spatial_dimension is the spatial + # dimension of the convolution: either 1, 2, or 3. byte_alignments + # is a triple of required minimum byte alignments for A, B, and C. + # + # Note that itertools functions produce a single-pass generator. + # The code doesn't need a multipass iterable, but if one did, one + # could call `tuple` or `list` on the generator. + # + # While this happens to use the same cluster sizes for each element, + # the code doesn't require that. Different convolution kinds, data + # types, or mma sizes might have different optimal cluster sizes. + combinations_of_parameters = chain( + # The following are all the kernels exercised in the unit tests. + # Please try to keep in sync with the unit tests. 
+ product( + ( + ConvKind.Fprop, + ), + spatial_dims, + ( + fp16_fp32_fp16_fp32, + fp16_fp32_fp32_fp32, + s8_s32_s32_s32, + ), + ( + mma_64x64x16, + ), + cluster_shapes + ), + product( + ( + ConvKind.Fprop, + ), + spatial_dims, + ( + fp32_fp32_fp32_fp32, + ), + ( + mma_64x64x8, + ), + cluster_shapes + ), + product( + ( + ConvKind.Dgrad, + ConvKind.Wgrad + ), + spatial_dims, + ( + fp16_fp32_fp16_fp32, + fp16_fp32_fp32_fp32, + ), + ( + mma_64x64x16, + ), + cluster_shapes + ), + # Kernels not necessarily in the unit tests, but used elsewhere + # and thus useful to have generated for profiling. They may + # duplicate kernels above. All of them are 2-D. In general, + # CUTLASS prefers 64 x 128 to 128 x 64 on sm90, even if the + # hardware permits 128 x 64. + ( + # Fprop + # + # bf16bf16_bf16f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, (128, 256, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, (128, 256, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, (256, 128, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, (256, 128, 16), (2, 1, 1)), + # + # f16f16_f16f16_f16 + # + # cluster shape (1, 1, 1) + # + (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 64, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 64, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 128, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 128, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 256, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 256, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (128, 128, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (128, 128, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (128, 256, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (128, 256, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (256, 64, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (256, 64, 16), (1, 
1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (256, 128, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (256, 128, 16), (1, 1, 1)), + # + # f16f16_f16f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (128, 192, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (128, 192, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (128, 256, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (128, 256, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (256, 96, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (256, 96, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (256, 128, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (256, 128, 16), (2, 1, 1)), + # + # f32f32_tf32f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Fprop, 2, f32f32_tf32f32_f32, (128, 192, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f32f32_tf32f32_f32, (128, 256, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f32f32_tf32f32_f32, (256, 128, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f32f32_tf32f32_f32, (256, 96, 8), (2, 1, 1)), + # + # i8i8_i8i32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Fprop, 2, i8i8_i8i32_f32, (128, 256, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, i8i8_i8i32_f32, (128, 256, 32), (2, 1, 1)), + (ConvKind.Fprop, 2, i8i8_i8i32_f32, (256, 128, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, i8i8_i8i32_f32, (256, 128, 32), (2, 1, 1)), + # + # Dgrad + # + # bf16bf16_bf16f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, (128, 256, 8), (2, 1, 1)), + (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, (128, 256, 16), (2, 1, 1)), + (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, (256, 128, 8), (2, 1, 1)), + (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, (256, 128, 16), (2, 1, 1)), + # + # f16f16_f16f16_f16 + # + # cluster shape (1, 1, 1) + # + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 64, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 64, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, 
f16f16_f16f16_f16, ( 64, 128, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 128, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 256, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 256, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (128, 128, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (128, 128, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (128, 256, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (128, 256, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (256, 64, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (256, 64, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (256, 128, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (256, 128, 16), (1, 1, 1)), + # + # f16f16_f16f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Dgrad, 2, f16f16_f16f32_f32, (128, 256, 8), (2, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f32_f32, (128, 256, 16), (2, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f32_f32, (256, 128, 8), (2, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f32_f32, (256, 128, 16), (2, 1, 1)), + ), + ) + + # SM >= 90 kernels don't actually use warp_count, but the + # TileDescription class needs it. The 4 in the default + # warp_count has nothing to do with num_mma_per_tile. 
+ warp_count = [4, 1, 1] + + stages = 0 # zero means "deduce the number of stages automatically" + + mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecializedSm90 + epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized + schedule_pairs = ( + (mainloop_schedule, epilogue_schedule), + ) + tile_schedulers = ( + TileSchedulerType.Default, # -> void + ) + + def make_math_instruction(data_types: Dict[str, DataType], + mma_shape: Tuple[int, int, int]) -> MathInstruction: + default_opcode = OpcodeClass.TensorOp + default_math_op = MathOperation.multiply_add + return MathInstruction( + mma_shape, + data_types['a_type'], data_types['b_type'], data_types['c_type'], + default_opcode, + default_math_op + ) + + for (conv_kind, spatial_dim, data_types, mma_shape, cluster_shape) in combinations_of_parameters: + math_inst = make_math_instruction(data_types, mma_shape) + tile_shape = (mma_shape[0], mma_shape[1], num_mma_per_tile * mma_shape[2]) + tile_description = TileDescription(tile_shape, stages, warp_count, math_inst, + minimum_compute_capability, maximum_compute_capability, cluster_shape) + assert(isinstance(spatial_dim, int)) + dims_and_alignments = ( + ( + (spatial_dim, data_types['alignment_A']), + (spatial_dim, data_types['alignment_B']), + (spatial_dim, data_types['alignment_C']), + ), + ) + CreateConvOperator3x(manifest, + dims_and_alignments = dims_and_alignments, + tile_descriptions = [tile_description], + data_types = data_types, + schedule_pairs = schedule_pairs, + tile_schedulers = tile_schedulers, + conv_kind = conv_kind, + log_indent_level = log_indent_level) + +def GenerateSM90(manifest, cuda_version): + GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_16b_WGMMA_alignx_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_tf32_WGMMA_alignx_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version) + 
def numeric_log_level(log_level: str) -> int:
    """
    Converts the string identifier of the log level
    into the numeric identifier used in setting the log level.

    The name is upper-cased before lookup, so it is matched case-insensitively.

    :param log_level: string representation of log level (e.g., 'INFO', 'DEBUG')
    :type log_level: str

    :return: numeric representation of log level
    :rtype: int

    :raises ValueError: if the upper-cased name is not an integer attribute
        of the ``logging`` module
    """
    numeric_level = getattr(logging, log_level.upper(), None)
    # Only integer attributes are accepted: the lookup can also hit
    # non-level attributes of the module (e.g. string constants).
    if isinstance(numeric_level, int):
        return numeric_level
    raise ValueError(f'Invalid log level: {log_level}')
In contrast to --ignore-kernels, ' + + 'this option always takes effect, ' + + 'whether or not --kernels is set to a nonempty value. ' + + 'It also can exclude kernels from the filter file ' + + '(see --kernel-filter-file option below).') + parser.add_argument("--filter-by-cc", default='True', type=str, help='If enabled, kernels whose compute capability range is not satisfied by the build target are excluded.') + parser.add_argument("--cuda-version", default="11.0.0", help="Semantic version string of CUDA Toolkit") + parser.add_argument('--kernel-filter-file', type=str, default=None, required=False, help='Full path of filter file') + parser.add_argument('--heuristics-problems-file', type=str, default=None, required=False, help='Full path of heuristics problem size description file, as a json list') + parser.add_argument('--heuristics-testlist-file', type=str, default=None, required=False, help='Full path of heuristics testlist CSV file, to be passed to cutlass_profiler') + parser.add_argument('--heuristics-gpu', type=str, default=None, required=False, help='GPU to use for evaluating heuristics offline. 
None or `auto` to autodetect using cuda', choices=['', 'auto', 'H100_SXM', 'H100_PCIE', 'H100_NVL', 'H200_SXM', 'H20_SXM', 'B200', 'GB200_NVL', 'RTX_5080', 'RTX_5090', 'RTX_PRO_6000']) + parser.add_argument('--heuristics-configs-per-problem', type=int, default=10, required=False, help='Number of kernel configs to generate for each problem in the problem list') + parser.add_argument('--heuristics-restrict-kernels', action='store_true', help='Restrict heuristics mode to use only the default set of kernels emitted by generator.py') + parser.add_argument('--selected-kernel-list', type=str, default=None, required=False, + help='Specify the output log file containing all enabled kernels in this build') + parser.add_argument("--interface-dir", default=None, required=False, help="Interface header to kernels") + parser.add_argument("--disable-full-archs-compilation", action="store_true", required=False, help="Disable compilation for every archs in --architectures") + parser.add_argument("--log-level", default='info', type=numeric_log_level, required=False, + help='Logging level to be used by the generator script') + parser.add_argument('--instantiation-level', type=str, default="", required=False, help="Instantiation level for SM90 kernels. 
Set to `max` and make sure `--kernels` is not empty to generate all possible configurations.") + _add_package_disablement_flag(parser) + return parser + + +if __name__ == "__main__": + parser = define_parser() + args = parser.parse_args() + + # Set the logging level based on the user-provided `--log-level` command-line option + logging.basicConfig(level=args.log_level) + + manifest = Manifest(args) + + archs = args.architectures.split(';') + + if args.heuristics_problems_file: + filter_manifest_and_write_heuristics_file(manifest, args) + + GenerateSM50(manifest, args.cuda_version) + GenerateSM60(manifest, args.cuda_version) + GenerateSM61(manifest, args.cuda_version) + GenerateSM70(manifest, args.cuda_version) + GenerateSM75(manifest, args.cuda_version) + GenerateSM80(manifest, args.cuda_version) + GenerateSM89(manifest, args.cuda_version) + GenerateSM90(manifest, args.cuda_version) + + blackwell_arch_list = [ + "100a", "100f", + "101a", "101f", + "103a", "103f", + "110a", "110f", + "120a", "120f", + "121a", "121f", + ] + blackwell_enabled_arch = any(arch in blackwell_arch_list for arch in archs) + if blackwell_enabled_arch: + GenerateSM100(manifest, args.cuda_version) + GenerateSM120(manifest, args.cuda_version) + + if 'library' in args.generator_target.split(','): + manifest.emit(GeneratorTarget.Library) + + if 'kernel_testlist_l0' in args.generator_target.split(','): + emit_gemm_kernel_testlist(manifest, args.curr_build_dir, args.architectures, "functional_L0") + + if 'kernel_testlist_l1' in args.generator_target.split(','): + emit_gemm_kernel_testlist(manifest, args.curr_build_dir, args.architectures, "functional_L1") + + if args.selected_kernel_list is not None: + if len(manifest.selected_kernels) > 0: + with open(args.selected_kernel_list, 'w') as file_writer: + for line in manifest.selected_kernels: + file_writer.write("%s\n" % line) + +################################################################################################### diff --git 
a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics.py new file mode 100644 index 0000000000000000000000000000000000000000..83421a06427acdc3b059855991cf95a1d2f118b3 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics.py @@ -0,0 +1,415 @@ +################################################################################################# +# +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for selecting CUTLASS library kernels based on problem description +""" +import json +import csv + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * + from cutlass_library.generator import * + from cutlass_library.heuristics_provider import * +except ImportError: + from library import * + from generator import * + from heuristics_provider import * + +try: + from .sm90_utils import ( + get_valid_schedules, + generate_data_types_from_math_instruction, + fix_alignments, + ) +except ImportError: + from sm90_utils import ( + get_valid_schedules, + generate_data_types_from_math_instruction, + fix_alignments, + ) + +_LOGGER = logging.getLogger(__name__) + +dtype_map = {v: k for k, v in DataTypeNames.items()} + +def serialize_heuristics_results_to_json(problems_with_configs, outfile_path): + """ + Utilitiy function to write heuristics results to a json file for debug + + args: + problems_with_configs: List of problems provided to the heuristic, with a list of operations added to each problem dict + outfile_path: Outfile path + + returns: + None + """ + pc_copy = problems_with_configs.copy() + for p in pc_copy: + for k, v in p.items(): + if isinstance(v, DataType): + 
p[k] = DataTypeNames[v] + elif isinstance(v, LayoutType): + p[k] = ShortLayoutTypeNames[v] + configs = p['configs'] + for c in configs: + for k, v in c.items(): + if isinstance(v, DataType): + c[k] = DataTypeNames[v] + elif isinstance(v, LayoutType): + c[k] = ShortLayoutTypeNames[v] + with open(outfile_path, 'w') as f: + json.dump(pc_copy, f, indent=2) + +def get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, voidC=False, use_fast_acc=True, count=1, provider=None): + """ + Get heuristic-suggested GEMM kernel configurations for a single GEMM problem. + + args: + m, n, k: GEMM dimensions + batch_count: batch count + layouts: tuple of layouts of type LayoutType + use_fast_acc: Use fast accumulation for FP8. Ignored for other precisions + count: Number of configs to return + provider: Heuristics provider to use + + returns: + A list of dictionaries containing the suggested kernel configurations and additional info from the input required to define a Cutlass GemmOperation, with the following keys: + - 'cta_tile_m', 'cta_tile_m', 'cta_tile_k': CTA tile size + - 'instr_tile_m', 'instr_tile_n', 'instr_tile_k': Instruction tile size + - 'stages': kernel pipeline stage count + - 'cluster_m', 'cluster_n', 'cluster_k': cluster size + - 'layout_a', 'layout_b': input tensor layouts of type LayoutType + - 'alignment_a', 'alignment_b': input tensor alignments, in count of elements + - 'dtype_a', 'dtype_b', 'dtype_acc': dtypes of a, b, and accumulator, of type DataType + - 'swizzle_size' : suggested threadblock swizzle + - 'split_k_slices': number of partitions of the k dimension for splitK + - 'raster_order': raster order for CTAs over output tiles ('along_m' or 'along_n') + """ + if provider is None: + provider = MatmulHeuristics() + return provider.get_configs(m, n, k, batch_count, dtypes, layouts, alignment_a, alignment_b, voidC=voidC, use_fast_acc=use_fast_acc, count=count) + +def get_gemm_configs(problems, provider=None, count=1): + """ + 
Get heuristic-suggested GEMM kernel configurations for a set of GEMM problems. + + args: + problems: List of dictionaries describing GEMM problems with the following keys: + - 'm', 'n', 'k': Matrix dimensions (required) + - 'dtype_a': Data type of matrix A (required) + - 'dtype_b': Data type of matrix B (required) + - 'dtype_c': Data type of matrix C (default: None) + - 'dtype_d': Data type of matrix D (required) + - 'dtype_acc': Compute data type (default 'f32') + - 'layout': Operation layout (e.g. 'tnt') + - 'alignment_a': Memory access granularity of A, in units of elements (default: 16 bytes equivalent elements) + - 'alignment_b': Memory access granularity of B, in units of elements (default: 16 bytes equivalent elements) + - 'alpha': Scalar multiplier for A*B (default: 1.0) + - 'beta': Scalar multiplier for C (default: 0.0) + - 'batch_count': Number of GEMM operations in batch (default: 1) + - 'use_fast_acc': Enable fast accumulation for FP8 on Hopper (default: True) + provider: Heuristics provider to use + count: Number of configurations to return per problem (defualt: 1) + + returns: + A copy of the input dictionary, with key `configs` added containing the selected gemm configs + """ + ret = [] + + for problem in problems: + problem = problem.copy() + + try: + m = problem['m'] + n = problem['n'] + k = problem['k'] + dtype_a = problem['dtype_a'] + dtype_b = problem['dtype_b'] + dtype_d = problem['dtype_d'] + layout = problem['layout'] + except KeyError as e: + _LOGGER.error(f"Missing required parameter {e} for problem {problem}") + raise + + operation = problem.get('operation', 'gemm') + batch_count = problem.get('batch_count', 1) + dtype_acc = problem.get('dtype_acc', 'f32') + dtype_c = problem.get('dtype_c', None) + alpha = problem.get('alpha', 1.0) + beta = problem.get('beta', 0.0) + use_fast_acc = problem.get('use_fast_acc', True) + + if operation != OperationKindNames[OperationKind.Gemm]: + raise ValueError(f"Unsupported operation {operation}") + if not 
(len(layout) == 3 and all(c in "nt" for c in layout)): + raise ValueError(f"layout must be a 3-character string containing only 'n' or 't', got {layout}") + layouts = tuple(LayoutType.RowMajor if l == 't' else LayoutType.ColumnMajor for l in layout) + + try: + dtype_list = [dtype_a.lower(), dtype_b.lower(), dtype_acc.lower(), dtype_c.lower() if dtype_c is not None else dtype_d.lower(), dtype_d.lower()] + dtypes = tuple(dtype_map[dt] for dt in dtype_list) + except KeyError as dt: + _LOGGER.error(f"Unsupported data type: {dt}") + raise + + alignment_a = problem.get('alignment_a', 128 // DataTypeSize[dtypes[0]]) + alignment_b = problem.get('alignment_b', 128 // DataTypeSize[dtypes[1]]) + + configs = get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, beta==0.0, use_fast_acc, count, provider) + problem['configs'] = configs + + ret.append(problem) + + return ret + + +def generate_sm100_from_heuristics_configs(manifest, cuda_version, kernel_configs): + """ + Generate CUTLASS operations based on the list of configs provided by the heuristic provider + + args: + manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest) + cuda_version: Cuda compiler version for generating cutlass operations + kernel_configs: list of configs generated by the heuristic + + returns: + (configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations + """ + min_cc = 100 + max_cc = 101 + if manifest is None: + # Use a dummy manifest so we can use existing CreateGemmOperator functions + manifest = Manifest() + + configs = [] + operations = [] + for config in kernel_configs: + layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [config['layout_d'], 128 // DataTypeSize[config['dtype_d']]]) + element_a, element_b, element_accumulator, element_c, element_d = 
config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d'] + + # nvMMH assumes 2sm instruction for !(cluster_m % 2) + is_2sm = config['cluster_m'] % 2 == 0 + instruction_shape = [(2 * config['cta_tile_m']) if is_2sm else config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k'] // 4] + math_instruction = MathInstruction( + instruction_shape, + element_a, element_b, element_accumulator, + OpcodeClass.TensorOp, + MathOperation.multiply_add + ) + + data_types = [ + { + "a_type" : math_instruction.element_a, + "b_type" : math_instruction.element_b, + "c_type" : DataType.void if config['voidC'] else math_instruction.element_accumulator, + "d_type" : element_d, + "acc_type" : math_instruction.element_accumulator, + "epi_type" : math_instruction.element_accumulator, + } + ] + + tile_multiplier = (config['cluster_m'] // (2 if is_2sm else 1), config['cluster_n'], config['cluster_k']) + tile_description = TileDescription( + [instruction_shape[0] * tile_multiplier[0], + instruction_shape[1] * tile_multiplier[1], + instruction_shape[2] * 4 * tile_multiplier[2]], + 0, + [4,1,1], + math_instruction, + min_cc, + max_cc, + cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k']) + ) + + schedules = [] + if is_2sm: + schedules.append([KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]) + else: + schedules.append([KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]) + + for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, tile_schedulers=[TileSchedulerType.Default, TileSchedulerType.StreamK], gemm_kind=GemmKind.Universal3x): + configs.append(config) + operations.append(o) + + + return configs, operations + + +def generate_sm90_from_heuristics_configs(manifest, cuda_version, kernel_configs): + """ + Generate CUTLASS operations based on the list of configs provided by the heuristic provider + + 
args: + manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest) + cuda_version: Cuda compiler version for generating cutlass operations + kernel_configs: list of configs generated by the heuristic + + returns: + (configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations + """ + min_cc, max_cc = 90, 90 + + if manifest is None: + # Use a dummy manifest so we can use existing CreateGemmOperator functions + manifest = Manifest() + + configs = [] + operations = [] + for config in kernel_configs: + + is_aligned = (config['alignment_a'] * DataTypeSize[config['dtype_a']] >= 128) and (config['alignment_b'] * DataTypeSize[config['dtype_b']] >= 128) + layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [LayoutType.ColumnMajor, 1]) + element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d'] + + # instr shape and warp config are unused for emitting 3x collective builder code + dummy_instr_shape = [0, 0, 0] + math_instruction = MathInstruction( + dummy_instr_shape, + element_a, element_b, element_accumulator, + OpcodeClass.TensorOp, + MathOperation.multiply_add + ) + + data_types = generate_data_types_from_math_instruction(math_instruction, element_source=element_c, element_dest=element_d) + if is_aligned: + layout = fix_alignments(data_types, layout, alignment_bits=128) + + # instr shape and warp config are unused for emitting 3x collective builder code + dummy_warp_count = [0, 0, 0] + tile_description = TileDescription( + [config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k']], + 0, + dummy_warp_count, + math_instruction, + min_cc, + max_cc, + cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k']) + ) + + schedules, 
stream_k_schedules = get_valid_schedules( + tile_description=tile_description, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_types, + instantiation_level=9000, # don't prune schedules: we didn't get any schedule suggestion from the heuristic + layout=layout, + gemm_kind=GemmKind.Universal3x, + enable_fp8_fast_acc=config['use_fast_acc'] + ) + + if len(schedules): + for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, gemm_kind=GemmKind.Universal3x): + configs.append(config) + operations.append(o) + + if len(stream_k_schedules): + for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]): + configs.append(config) + operations.append(o) + + + return configs, operations + +def filter_manifest_and_write_heuristics_file(manifest, args): + """ + Prune a manifest according to heuristics suggestions from the problems file + + args: + manifest: Cutlass manifest to prune + args: generator.py args, requires: + - args.heuristics_problems_file + - args.heuristics_gpu + - args.heuristics_testlist_file + + returns: + A list of dictionaries, each of which has information about an operation and a problem from the input problems + """ + heuristics_problems = [] + with open(args.heuristics_problems_file, 'r') as f: + heuristics_problems = json.load(f) + gpu = None if (args.heuristics_gpu == "auto" or args.heuristics_gpu == "") else args.heuristics_gpu + mmh = MatmulHeuristics(gpu=gpu) + if any(('100' in arch) for arch in args.architectures.split(';')): + mmh.set_cta_div_n(64) + problems_with_configs = get_gemm_configs(heuristics_problems, provider=mmh, count=args.heuristics_configs_per_problem) + + all_configs_and_operations = [] + operations = [] + for problem in problems_with_configs: + if any('90' in arch for arch in args.architectures.split(';')): + problem_configs, problem_operations = 
generate_sm90_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs']) + if any(('100' in arch) or ('101' in arch) for arch in args.architectures.split(';')): + problem_configs, problem_operations = generate_sm100_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs']) + + operations += problem_operations + problem_without_configs = {k: v for k, v in problem.items() if k != 'configs'} + with_problem_size = [{'operation_name': o.procedural_name(), **problem_without_configs, **c} for c, o in zip(problem_configs, problem_operations)] + all_configs_and_operations += with_problem_size + + for operation in operations: + manifest.add_kernel_filter(f"^{operation.procedural_name()}$") + if not all_configs_and_operations: + raise Exception("No valid configurations generated") + write_profiler_testlist_to_csv(all_configs_and_operations, args.heuristics_testlist_file) + return all_configs_and_operations + +def write_profiler_testlist_to_csv(configs_list, outfile_path): + """ + Write a list of configs to a testlist to be consumed by cutlass_profiler + + args: + configs_list: List of kernel configs along with runtime arguments and any other columns to include in the CSV, expressed as a list of dictionaries + outfile_path: Outfile path + + returns: + None + """ + profiler_testlist = configs_list.copy() + for c in profiler_testlist: + for k, v in c.items(): + if isinstance(v, DataType): + c[k] = DataTypeNames[v] + elif isinstance(v, LayoutType): + c[k] = ShortLayoutTypeNames[v] + + with open(outfile_path, mode='w', newline='') as ofile: + k_names = profiler_testlist[0].keys() + + writer = csv.DictWriter(ofile, fieldnames=k_names) + writer.writeheader() + writer.writerows(profiler_testlist) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics_provider.py 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..01a4112a34c87d73a792cce368fede96a9315ac1 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics_provider.py @@ -0,0 +1,175 @@ +################################################################################################# +# +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Providers for kernel selection heuristics +""" + +import sys +import os +import glob +import logging +import ctypes +import functools + + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import DataType, LayoutType +except ImportError: + from library import DataType, LayoutType + +class MatmulHeuristics: + + def __init__(self, gpu = None): + import nvMatmulHeuristics + self.mmh_lib = nvMatmulHeuristics + self.gpu = gpu + + if 'CUTLASS_NVMMH_SO_PATH' in os.environ: + nvmmhInterfaceEx = functools.partial(self.mmh_lib.NvMatmulHeuristicsInterfaceEx, path=os.environ['CUTLASS_NVMMH_SO_PATH']) + else: + nvmmhInterfaceEx = self.mmh_lib.NvMatmulHeuristicsInterfaceEx + + self.lh = nvmmhInterfaceEx( + backend=self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"], + flags=self.mmh_lib.NvMatmulHeuristicsFlags.PERF_MODEL_BASED_AUTO_TUNING, + load_discovery_implicitly=True, + gpu=self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None + ) + self.backend = self.lh.createBackend(self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"]) + + def _layout_from_cutlass(self, layouts): + assert(len(layouts)==3) + full_layout_str = ''.join('t' if l == LayoutType.RowMajor else 'n' for l in layouts) 
+ input_layouts = full_layout_str[:2].upper() + lh_layout = input_layouts + '_' + str("ROW_MAJOR" if full_layout_str[-1]=='t' else "COL_MAJOR") + return self.mmh_lib.NvMatmulHeuristicsMatmulLayout[lh_layout] + + def _precision_from_cutlass_dtypes(self, dtypes): + dtype_to_cublas = { + DataType.f64: 'D', + DataType.f32: 'S', + DataType.f16: 'H', + DataType.bf16: 'T', + DataType.e4m3: 'Q', + DataType.e5m2: 'R', + DataType.s32: 'I', + DataType.s8: 'B', + } + + dtype_a, dtype_b, dtype_compute, dtype_c, dtype_d = dtypes + + a_c = dtype_to_cublas[dtype_a] + + if a_c.lower() != 'q': + return a_c + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d] + else: + return a_c + dtype_to_cublas[dtype_b] + dtype_to_cublas[dtype_c] + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d] + + def set_cta_div_n(self, div_n): + cta_n_div_requirement = ctypes.c_int(div_n) + self.lh.setBackendValueProperty( + self.backend, + self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT, + ctypes.byref(cta_n_div_requirement), + ctypes.sizeof(cta_n_div_requirement) + ) + + def set_cta_div_m(self, div_m): + cta_m_div_requirement = ctypes.c_int(div_m) + self.lh.setBackendValueProperty( + self.backend, + self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT, + ctypes.byref(cta_m_div_requirement), + ctypes.sizeof(cta_m_div_requirement) + ) + + def get_configs(self, m, n, k, batch_count, dtypes, layouts, align_a, align_b, voidC=False, use_fast_acc=True, count=1): + if use_fast_acc: + disable_fast_acc_for_fp8 = ctypes.c_int(0) + else: + disable_fast_acc_for_fp8 = ctypes.c_int(1) + self.lh.setBackendValueProperty( + self.backend, + self.mmh_lib.NvMatmulHeuristicsBackendProperty.DISABLE_FAST_ACC_FOR_FP8, + ctypes.byref(disable_fast_acc_for_fp8), + ctypes.sizeof(disable_fast_acc_for_fp8) + ) + + precision = self._precision_from_cutlass_dtypes(dtypes) + layout = self._layout_from_cutlass(layouts) + + matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, 
n, k, layout, batch_count) + configs = self.lh.getEx(matmul_problem, count, self.backend, precision=precision) + + ret = [] + for c in configs: + kernel = c['kernel'] + problem = c['problem'] + + r = {} + r['estimated_runtime'] = c['runtime'] + r['cta_tile_m'] = kernel.cta_tile_m + r['cta_tile_n'] = kernel.cta_tile_n + r['cta_tile_k'] = kernel.cta_tile_k + r['instr_tile_m'] = kernel.instr_tile_m + r['instr_tile_n'] = kernel.instr_tile_n + r['instr_tile_k'] = kernel.instr_tile_k + r['warp_tile_m'] = kernel.warp_tile_m + r['warp_tile_n'] = kernel.warp_tile_n + r['warp_tile_k'] = kernel.warp_tile_k + r['cluster_m'] = kernel.cluster_m + r['cluster_n'] = kernel.cluster_n + r['cluster_k'] = 1 + r['layout_a'] = layouts[0] + r['layout_b'] = layouts[1] + r['layout_d'] = layouts[2] + r['dtype_a'] = dtypes[0] + r['dtype_b'] = dtypes[1] + r['dtype_acc'] = dtypes[2] + r['dtype_c'] = dtypes[3] + r['dtype_d'] = dtypes[4] + r['alignment_a'] = align_a + r['alignment_b'] = align_b + r['swizzle_size'] = kernel.swizzle_factor + r['raster_order'] = 'along_m' if kernel.cta_order==0 else 'along_n' + r['split_k_slices'] = kernel.split_k + r['use_fast_acc'] = use_fast_acc + r['voidC'] = voidC + + ret.append(r) + + return ret + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/library.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/library.py new file mode 100644 index 0000000000000000000000000000000000000000..56d22dc4b0705b4813b15b1b09decf53b38f7f37 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/library.py @@ -0,0 +1,1531 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Data types and tags used for emitting CUTLASS C++ kernels +""" + +import enum +import re + +# The following block implements enum.auto() for Python 3.5 variants that don't include it such +# as the default 3.5.2 on Ubuntu 16.04. 
+# +# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility + +try: + from enum import auto as enum_auto +except ImportError: + __cutlass_library_auto_enum = 0 + def enum_auto() -> int: + global __cutlass_library_auto_enum + i = __cutlass_library_auto_enum + __cutlass_library_auto_enum += 1 + return i + +################################################################################################### + +# +class GeneratorTarget(enum.Enum): + Library = enum_auto() +# +GeneratorTargetNames = { + GeneratorTarget.Library: 'library' +} +# + +################################################################################################### + +# +class DataType(enum.Enum): + void = enum_auto() # primarily used to disable C tensor for epilogues + b1 = enum_auto() + u2 = enum_auto() + u4 = enum_auto() + u8 = enum_auto() + u16 = enum_auto() + u32 = enum_auto() + u64 = enum_auto() + s2 = enum_auto() + s4 = enum_auto() + s8 = enum_auto() + s16 = enum_auto() + s32 = enum_auto() + s64 = enum_auto() + e4m3 = enum_auto() + e5m2 = enum_auto() + f8 = enum_auto() + f6 = enum_auto() + f4 = enum_auto() + e3m2 = enum_auto() + e2m3 = enum_auto() + e2m1 = enum_auto() + ue8m0 = enum_auto() + ue4m3 = enum_auto() + f16 = enum_auto() + bf16 = enum_auto() + f32 = enum_auto() + tf32 = enum_auto() + f64 = enum_auto() + cf16 = enum_auto() + cbf16 = enum_auto() + cf32 = enum_auto() + ctf32 = enum_auto() + cf64 = enum_auto() + cs2 = enum_auto() + cs4 = enum_auto() + cs8 = enum_auto() + cs16 = enum_auto() + cs32 = enum_auto() + cs64 = enum_auto() + cu2 = enum_auto() + cu4 = enum_auto() + cu8 = enum_auto() + cu16 = enum_auto() + cu32 = enum_auto() + cu64 = enum_auto() + invalid = enum_auto() + +# +ShortDataTypeNames = { + DataType.s32: 'i', + DataType.e4m3: 'e4m3', + DataType.e5m2: 'e5m2', + DataType.f16: 'h', + DataType.f32: 's', + DataType.f64: 'd', + DataType.cf32: 'c', + DataType.cf64: 'z', + DataType.f8: 'f8', + DataType.f6: 'f6', + 
DataType.f4: 'f4', +} + +# +DataTypeNames = { + DataType.void: "void", + DataType.b1: "b1", + DataType.u2: "u2", + DataType.u4: "u4", + DataType.u8: "u8", + DataType.u16: "u16", + DataType.u32: "u32", + DataType.u64: "u64", + DataType.s2: "s2", + DataType.s4: "s4", + DataType.s8: "s8", + DataType.s16: "s16", + DataType.s32: "s32", + DataType.s64: "s64", + DataType.e4m3: 'e4m3', + DataType.e5m2: 'e5m2', + DataType.f8: 'f8', + DataType.f6: 'f6', + DataType.f4: 'f4', + DataType.e2m3: 'e2m3', + DataType.e3m2: 'e3m2', + DataType.e2m1: 'e2m1', + DataType.ue8m0: 'ue8m0', + DataType.ue4m3: 'ue4m3', + DataType.f16: "f16", + DataType.bf16: "bf16", + DataType.f32: "f32", + DataType.tf32: "tf32", + DataType.f64: "f64", + DataType.cf16: "cf16", + DataType.cbf16: "cbf16", + DataType.cf32: "cf32", + DataType.ctf32: "ctf32", + DataType.cf64: "cf64", + DataType.cu2: "cu2", + DataType.cu4: "cu4", + DataType.cu8: "cu8", + DataType.cu16: "cu16", + DataType.cu32: "cu32", + DataType.cu64: "cu64", + DataType.cs2: "cs2", + DataType.cs4: "cs4", + DataType.cs8: "cs8", + DataType.cs16: "cs16", + DataType.cs32: "cs32", + DataType.cs64: "cs64", +} + +DataTypeTag = { + DataType.void: "void", + DataType.b1: "cutlass::uint1b_t", + DataType.u2: "cutlass::uint2b_t", + DataType.u4: "cutlass::uint4b_t", + DataType.u8: "uint8_t", + DataType.u16: "uint16_t", + DataType.u32: "uint32_t", + DataType.u64: "uint64_t", + DataType.s2: "cutlass::int2b_t", + DataType.s4: "cutlass::int4b_t", + DataType.s8: "int8_t", + DataType.s16: "int16_t", + DataType.s32: "int32_t", + DataType.s64: "int64_t", + DataType.e4m3: 'cutlass::float_e4m3_t', + DataType.e5m2: 'cutlass::float_e5m2_t', + DataType.f8: 'cutlass::type_erased_dynamic_float8_t', + DataType.f6: 'cutlass::type_erased_dynamic_float6_t', + DataType.f4: 'cutlass::type_erased_dynamic_float4_t', + DataType.e2m3: 'cutlass::float_e2m3_t', + DataType.e3m2: 'cutlass::float_e3m2_t', + DataType.e2m1: 'cutlass::float_e2m1_t', + DataType.ue8m0: 'cutlass::float_ue8m0_t', + 
DataType.ue4m3: 'cutlass::float_ue4m3_t', + DataType.f16: "cutlass::half_t", + DataType.bf16: "cutlass::bfloat16_t", + DataType.f32: "float", + DataType.tf32: "cutlass::tfloat32_t", + DataType.f64: "double", + DataType.cf16: "cutlass::complex", + DataType.cbf16: "cutlass::complex", + DataType.cf32: "cutlass::complex", + DataType.ctf32: "cutlass::complex", + DataType.cf64: "cutlass::complex", + DataType.cu2: "cutlass::complex", + DataType.cu4: "cutlass::complex", + DataType.cu8: "cutlass::complex", + DataType.cu16: "cutlass::complex", + DataType.cu32: "cutlass::complex", + DataType.cu64: "cutlass::complex", + DataType.cs2: "cutlass::complex", + DataType.cs4: "cutlass::complex", + DataType.cs8: "cutlass::complex", + DataType.cs16: "cutlass::complex", + DataType.cs32: "cutlass::complex", + DataType.cs64: "cutlass::complex", +} + +DataTypeSize = { + DataType.void: 0, + DataType.b1: 1, + DataType.u2: 2, + DataType.u4: 4, + DataType.u8: 8, + DataType.u16: 16, + DataType.u32: 32, + DataType.u64: 64, + DataType.s2: 2, + DataType.s4: 4, + DataType.s8: 8, + DataType.s16: 16, + DataType.s32: 32, + DataType.s64: 64, + DataType.e4m3: 8, + DataType.e5m2: 8, + DataType.f8: 8, + DataType.f6: 6, + DataType.f4: 4, + DataType.e2m3: 6, + DataType.e3m2: 6, + DataType.e2m1: 4, + DataType.ue8m0: 8, + DataType.ue4m3: 8, + DataType.f16: 16, + DataType.bf16: 16, + DataType.f32: 32, + DataType.tf32: 32, + DataType.f64: 64, + DataType.cf16: 32, + DataType.cbf16: 32, + DataType.cf32: 64, + DataType.ctf32: 32, + DataType.cf64: 128, + DataType.cu2: 4, + DataType.cu4: 8, + DataType.cu8: 16, + DataType.cu16: 32, + DataType.cu32: 64, + DataType.cu64: 128, + DataType.cs2: 4, + DataType.cs4: 8, + DataType.cs8: 16, + DataType.cs16: 32, + DataType.cs32: 64, + DataType.cs64: 128, +} + +################################################################################################### +# +class BlasMode(enum.Enum): + symmetric = enum_auto() + hermitian = enum_auto() + +# +BlasModeTag = { + 
BlasMode.symmetric: 'cutlass::BlasMode::kSymmetric', + BlasMode.hermitian: 'cutlass::BlasMode::kHermitian', +} + +# +class ComplexTransform(enum.Enum): + none = enum_auto() + conj = enum_auto() + +# +ComplexTransformTag = { + ComplexTransform.none: 'cutlass::ComplexTransform::kNone', + ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate', +} + +# Used for cutlass3x complex kernel collective mainloop builder instantiation +ComplexTransformTag3x = { + ComplexTransform.none: 'cute::identity', + ComplexTransform.conj: 'cute::conjugate', +} + +# +RealComplexBijection = [ + (DataType.f16, DataType.cf16), + (DataType.f32, DataType.cf32), + (DataType.f64, DataType.cf64), +] + +# +def is_complex(data_type): + for r, c in RealComplexBijection: + if data_type == c: + return True + return False + +def is_block_scaled(gemm_kind): + return gemm_kind in (GemmKind.BlockScaledUniversal3x, GemmKind.GroupedBlockScaledUniversal3x) + +def is_blockwise(gemm_kind): + return gemm_kind in (GemmKind.BlockwiseUniversal3x, GemmKind.GroupedBlockwiseUniversal3x) + +def is_grouped(gemm_kind): + return gemm_kind in (GemmKind.GroupedUniversal3x, + GemmKind.GroupedBlockScaledUniversal3x, GemmKind.GroupedBlockwiseUniversal3x) + +# +def get_complex_from_real(real_type): + for r, c in RealComplexBijection: + if real_type == r: + return c + return DataType.invalid + +# +def get_real_from_complex(complex_type): + for r, c in RealComplexBijection: + if complex_type == c: + return r + return DataType.invalid + +# TMA requires an alignment of 128 bits for all data types +def get_tma_alignment(data_type): + if data_type == DataType.void: + return 0 + elif DataTypeSize[data_type] == 6: + return 128 # 96B alignment for 16U6 format + else: + return 128 // DataTypeSize[data_type] + +# +class ComplexMultiplyOp(enum.Enum): + multiply_add = enum_auto() + gaussian = enum_auto() + +################################################################################################### + +# +class 
MathOperation(enum.Enum): + multiply_add = enum_auto() + multiply_add_saturate = enum_auto() + multiply_add_mixed_input_upcast = enum_auto() + xor_popc = enum_auto() + and_popc = enum_auto() + multiply_add_fast_bf16 = enum_auto() + multiply_add_fast_f16 = enum_auto() + multiply_add_fast_f32 = enum_auto() + multiply_add_complex_fast_f32 = enum_auto() + multiply_add_complex = enum_auto() + multiply_add_complex_gaussian = enum_auto() + multiply_add_fast_accum = enum_auto() + +# +MathOperationTag = { + MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd', + MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate', + MathOperation.multiply_add_mixed_input_upcast: 'cutlass::arch::OpMultiplyAddMixedInputUpcast', + MathOperation.xor_popc: 'cutlass::arch::OpXorPopc', + MathOperation.and_popc: 'cutlass::arch::OpAndPopc', + MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16', + MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16', + MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32', + MathOperation.multiply_add_complex_fast_f32: 'cutlass::arch::OpMultiplyAddComplexFastF32', + MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex', + MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex', + MathOperation.multiply_add_fast_accum: 'cutlass::arch::OpMultiplyAddFastAccum', +} + +################################################################################################### + +# +class LayoutType(enum.Enum): + ColumnMajor = enum_auto() + RowMajor = enum_auto() + ColumnMajorInterleaved2 = enum_auto() + RowMajorInterleaved2 = enum_auto() + ColumnMajorInterleaved32 = enum_auto() + RowMajorInterleaved32 = enum_auto() + ColumnMajorInterleaved64 = enum_auto() + RowMajorInterleaved64 = enum_auto() + TensorNWC = enum_auto() + TensorNHWC = enum_auto() + TensorNDHWC = enum_auto() + TensorNCHW = enum_auto() + TensorNGHWC = enum_auto() + 
TensorNC32HW32 = enum_auto() + TensorNC64HW64 = enum_auto() + TensorC32RSK32 = enum_auto() + TensorC64RSK64 = enum_auto() + TensorKCS = enum_auto() + TensorKCSR = enum_auto() + TensorKCSRT = enum_auto() + +# +LayoutTag = { + LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor', + LayoutType.RowMajor: 'cutlass::layout::RowMajor', + LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>', + LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>', + LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>', + LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>', + LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>', + LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>', + LayoutType.TensorNWC: 'cutlass::layout::TensorNWC', + LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC', + LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC', + LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW', + LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC', + LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>', + LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>', + LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>', + LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>', + LayoutType.TensorKCS: 'cutlass::layout::TensorKCS', + LayoutType.TensorKCSR: 'cutlass::layout::TensorKCSR', + LayoutType.TensorKCSRT: 'cutlass::layout::TensorKCSRT' +} + +# +TransposedLayout = { + LayoutType.ColumnMajor: LayoutType.RowMajor, + LayoutType.RowMajor: LayoutType.ColumnMajor, + LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2, + LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2, + LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32, + LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32, + LayoutType.ColumnMajorInterleaved64: 
LayoutType.RowMajorInterleaved64, + LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64, + LayoutType.TensorNHWC: LayoutType.TensorNHWC +} + +# +ShortLayoutTypeNames = { + LayoutType.ColumnMajor: 'n', + LayoutType.ColumnMajorInterleaved2: 'n2', + LayoutType.ColumnMajorInterleaved32: 'n32', + LayoutType.ColumnMajorInterleaved64: 'n64', + LayoutType.RowMajor: 't', + LayoutType.RowMajorInterleaved2: 't2', + LayoutType.RowMajorInterleaved32: 't32', + LayoutType.RowMajorInterleaved64: 't64', + LayoutType.TensorNWC: 'nwc', + LayoutType.TensorNHWC: 'nhwc', + LayoutType.TensorNDHWC: 'ndhwc', + LayoutType.TensorNCHW: 'nchw', + LayoutType.TensorNGHWC: 'nghwc', + LayoutType.TensorNC32HW32: 'nc32hw32', + LayoutType.TensorNC64HW64: 'nc64hw64', + LayoutType.TensorC32RSK32: 'c32rsk32', + LayoutType.TensorC64RSK64: 'c64rsk64', + LayoutType.TensorKCS: 'kcs', + LayoutType.TensorKCSR: 'kcsr', + LayoutType.TensorKCSRT: 'kcsrt' +} + +# +ShortComplexLayoutNames = { + (LayoutType.ColumnMajor, ComplexTransform.none): 'n', + (LayoutType.ColumnMajor, ComplexTransform.conj): 'c', + (LayoutType.RowMajor, ComplexTransform.none): 't', + (LayoutType.RowMajor, ComplexTransform.conj): 'h' +} + +################################################################################################### +class KernelScheduleType(enum.Enum): + ScheduleAuto = enum_auto() + Multistage = enum_auto() + CpAsyncWarpSpecialized = enum_auto() + CpAsyncWarpSpecializedPingpong = enum_auto() + CpAsyncWarpSpecializedCooperative = enum_auto() + Tma = enum_auto() + TmaWarpSpecialized = enum_auto() + TmaWarpSpecializedPingpong = enum_auto() + TmaWarpSpecializedCooperative = enum_auto() + TmaWarpSpecializedFP8FastAccum = enum_auto() + TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() + TmaWarpSpecializedPingpongFP8FastAccum = enum_auto() + ImplicitTmaWarpSpecializedSm90 = enum_auto() + PtrArrayTmaWarpSpecializedCooperative = enum_auto() + PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() 
+ PtrArrayTmaWarpSpecializedPingpong = enum_auto() + PtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto() + + BlockwiseTmaWarpSpecializedCooperative = enum_auto() + PtrArrayBlockwiseTmaWarpSpecializedCooperative = enum_auto() + BlockwiseTmaWarpSpecializedPingpong = enum_auto() + PtrArrayBlockwiseTmaWarpSpecializedPingpong = enum_auto() + + TmaWarpSpecialized1SmSm100 = enum_auto() + TmaWarpSpecialized2SmSm100 = enum_auto() + ImplicitTmaWarpSpecialized1SmSm100 = enum_auto() + ImplicitTmaWarpSpecialized2SmSm100 = enum_auto() + + PtrArrayTmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayTmaWarpSpecialized2SmSm100 = enum_auto() + + PtrArrayTmaWarpSpecialized1SmBlockScaledSm100 = enum_auto() + PtrArrayTmaWarpSpecialized2SmBlockScaledSm100 = enum_auto() + PtrArrayNvf4TmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayNvf4TmaWarpSpecialized2SmSm100 = enum_auto() + PtrArrayMxf4TmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayMxf4TmaWarpSpecialized2SmSm100 = enum_auto() + PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto() + + SparseTmaWarpSpecialized1SmSm100 = enum_auto() + SparseTmaWarpSpecialized2SmSm100 = enum_auto() + + BlockScaledTmaWarpSpecialized1SmSm100 = enum_auto() + BlockScaledTmaWarpSpecialized2SmSm100 = enum_auto() + Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto() + Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto() + + BlockwiseTmaWarpSpecialized1SmSm100 = enum_auto() + BlockwiseTmaWarpSpecialized2SmSm100 = enum_auto() + + PtrArrayBlockwiseTmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayBlockwiseTmaWarpSpecialized2SmSm100 = enum_auto() + + + Mxf4TmaWarpSpecialized1SmSm100 = enum_auto() + Mxf4TmaWarpSpecialized2SmSm100 = enum_auto() + Nvf4TmaWarpSpecialized1SmSm100 = enum_auto() + Nvf4TmaWarpSpecialized2SmSm100 = enum_auto() + + # FP4 Ultra + MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto() + 
MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto() + + MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto() + + MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto() + + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto() + + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto() + + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto() + + Mxf8f6f4TmaWarpSpecializedCooperativeSm120 = enum_auto() + Mxf8f6f4TmaWarpSpecializedPingpongSm120 = enum_auto() + Nvf4TmaWarpSpecializedCooperativeSm120 = enum_auto() + Nvf4TmaWarpSpecializedPingpongSm120 = enum_auto() + Mxf4TmaWarpSpecializedCooperativeSm120 = enum_auto() + Mxf4TmaWarpSpecializedPingpongSm120 = enum_auto() + + F8f6f4SparseTmaWarpSpecializedCooperativeSm120 = enum_auto() + + 
BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto() + BlockwiseTmaWarpSpecializedPingpongSm120 = enum_auto() + +KernelScheduleTag = { + KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto', + KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage', + KernelScheduleType.CpAsyncWarpSpecialized: 'cutlass::gemm::KernelCpAsyncWarpSpecialized', + KernelScheduleType.CpAsyncWarpSpecializedPingpong: 'cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong', + KernelScheduleType.CpAsyncWarpSpecializedCooperative: 'cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative', + KernelScheduleType.Tma: 'cutlass::gemm::KernelTma', + KernelScheduleType.TmaWarpSpecialized: 'cutlass::gemm::KernelTmaWarpSpecialized', + KernelScheduleType.TmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpong', + KernelScheduleType.TmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperative', + KernelScheduleType.TmaWarpSpecializedFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum', + KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum', + KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum', + KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90', + + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise', + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8Blockwise', + + KernelScheduleType.TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmSm100', + KernelScheduleType.TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmSm100', + + KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100: 'cutlass::conv::KernelImplicitTmaWarpSpecialized1SmSm100', + 
KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100: 'cutlass::conv::KernelImplicitTmaWarpSpecialized2SmSm100', + + KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100', + KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100', + + KernelScheduleType.SparseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100', + KernelScheduleType.SparseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100', + + KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100', + KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100', + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100', + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100', + + KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100', + KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100', + + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100', + + KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100', + KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100', + KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100', + KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100', + + # 
FP4 Ultra + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103', + + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + + KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 
'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative', + KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum', + KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong', + KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum', + + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise', + + KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100", + KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100", + KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100", + KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100", + KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100", + KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100", + KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100", + KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100", + + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103', + 
KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: 
'cutlass::gemm::KernelTmaWarpSpecializedMxf8f6f4Sm120', + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf8f6f4Sm120', + KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120', + KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongNvf4Sm120', + KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedMxf4Sm120', + KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf4Sm120', + + KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelScheduleSparseF8f6f4Sm120', + + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120', + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120', +} + +# +KernelScheduleSuffixes = { + KernelScheduleType.ScheduleAuto: '', + KernelScheduleType.Multistage: '_cpasync', + KernelScheduleType.CpAsyncWarpSpecialized: '_cpasync_warpspecialized', + KernelScheduleType.CpAsyncWarpSpecializedPingpong: '_cpasync_warpspecialized_pingpong', + KernelScheduleType.CpAsyncWarpSpecializedCooperative: '_cpasync_warpspecialized_cooperative', + KernelScheduleType.Tma: '_unspecialized', + KernelScheduleType.TmaWarpSpecialized: '_warpspecialized', + KernelScheduleType.TmaWarpSpecializedPingpong: '_warpspecialized_pingpong', + KernelScheduleType.TmaWarpSpecializedCooperative: '_warpspecialized_cooperative', + KernelScheduleType.TmaWarpSpecializedFP8FastAccum: '_warpspecialized_fp8_fastaccum', + KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', + KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum', + 
KernelScheduleType.ImplicitTmaWarpSpecializedSm90: '_warpspecialized', + + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', + + KernelScheduleType.TmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.TmaWarpSpecialized2SmSm100: '_2sm', + + KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100: '_2sm', + + KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: '_2sm', + + KernelScheduleType.SparseTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.SparseTmaWarpSpecialized2SmSm100: '_2sm', + + KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: '_2sm', + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm', + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm', + + KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: '_2sm', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: '_2sm', + + KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', + KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', + KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm', + KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm', + + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_ultra_2sm', + + 
KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_1sm_nopf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_2sm_nopf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_1sm_nopf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_2sm_nopf', + + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_1sm_tmapf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_2sm_tmapf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_1sm_tmapf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_2sm_tmapf', + + KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', + KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', + KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', + KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum', + + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', + + KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm', + KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm', + KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm', + KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm', + KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', + KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', + KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', + 
KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', + + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_ultra_2sm', + + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_1sm_nopf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_2sm_nopf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_1sm_nopf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_2sm_nopf', + + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_1sm_tmapf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_2sm_tmapf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_1sm_tmapf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_2sm_tmapf', + + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: '_cooperative_q', + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: '_pingpong_q', + KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs16', + KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: '_pingpong_o_vs16', + KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs32', + KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: '_pingpong_o_vs32', + + KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: '_q', + + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: '_cooperative_q', + 
class EpilogueScheduleType(enum.Enum):
  """Identifiers for the epilogue schedules known to the generator.

  Member order matters: values are assigned sequentially by enum_auto().
  """
  ScheduleAuto = enum_auto()
  EpilogueTransposed = enum_auto()
  NoSmemWarpSpecialized = enum_auto()
  PtrArrayNoSmemWarpSpecialized = enum_auto()
  NoSmemWarpSpecialized1Sm = enum_auto()
  NoSmemWarpSpecialized2Sm = enum_auto()
  FastF32NoSmemWarpSpecialized1Sm = enum_auto()
  FastF32NoSmemWarpSpecialized2Sm = enum_auto()
  BlockwiseNoSmemWarpSpecialized1Sm = enum_auto()
  BlockwiseNoSmemWarpSpecialized2Sm = enum_auto()
  PtrArrayNoSmemWarpSpecialized1Sm = enum_auto()
  PtrArrayNoSmemWarpSpecialized2Sm = enum_auto()
  PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto()
  PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto()
  PtrArrayBlockwiseNoSmemWarpSpecialized1Sm = enum_auto()
  PtrArrayBlockwiseNoSmemWarpSpecialized2Sm = enum_auto()
  TmaWarpSpecialized = enum_auto()
  TmaWarpSpecializedCooperative = enum_auto()
  TmaWarpSpecialized1Sm = enum_auto()
  TmaWarpSpecialized2Sm = enum_auto()
  PtrArrayTmaWarpSpecialized1Sm = enum_auto()
  PtrArrayTmaWarpSpecialized2Sm = enum_auto()
  PtrArrayTmaWarpSpecializedPingpong = enum_auto()
  PtrArrayTmaWarpSpecializedCooperative = enum_auto()

# Tag string (a C++-qualified name) for every EpilogueScheduleType member.
EpilogueScheduleTag = {
  EpilogueScheduleType.ScheduleAuto: 'cutlass::epilogue::collective::EpilogueScheduleAuto',
  EpilogueScheduleType.EpilogueTransposed: 'cutlass::gemm::EpilogueTransposed',
  EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized',
  EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized',
  EpilogueScheduleType.NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::NoSmemWarpSpecialized1Sm',
  EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm',
  EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized1Sm',
  EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm',
  EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm',
  EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm',
  EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm',
  EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm',
  EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized1Sm',
  EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized2Sm',
  EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized1Sm',
  EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized2Sm',
  EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
  EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
  EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm',
  EpilogueScheduleType.TmaWarpSpecialized2Sm: 'cutlass::epilogue::TmaWarpSpecialized2Sm',
  EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm',
  EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm',
  EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative',
  EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong',
}

# Suffix fragment for every EpilogueScheduleType member.
# Note that the two *TmaWarpSpecialized1Sm members map to the empty string
# while their 2Sm counterparts map to '_epi_tma'.
EpilogueScheduleSuffixes = {
  EpilogueScheduleType.ScheduleAuto: '',
  EpilogueScheduleType.EpilogueTransposed: '',
  EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem',
  EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: '_epi_nosmem',
  EpilogueScheduleType.NoSmemWarpSpecialized1Sm: '_epi_nosmem',
  EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem',
  EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
  EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
  EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem',
  EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem',
  EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem',
  EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem',
  EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
  EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
  EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem',
  EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem',
  EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
  EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
  EpilogueScheduleType.TmaWarpSpecialized1Sm: '',
  EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma',
  EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '',
  EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_epi_tma',
  EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma',
  EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma',
}

class EpilogueFunctor3x(enum.Enum):
  """Epilogue fusion functors selectable for 3.x-style kernels."""
  LinearCombination = enum_auto()
  LinearCombinationBlockScaleFactor = enum_auto()

# Tag string (a C++-qualified name) for every EpilogueFunctor3x member.
EpilogueFunctor3xTag = {
  EpilogueFunctor3x.LinearCombination: 'cutlass::epilogue::fusion::LinearCombination',
  EpilogueFunctor3x.LinearCombinationBlockScaleFactor: 'cutlass::epilogue::fusion::LinCombBlockScaleFactor',
}

# TMA epilogues have certain alignment requirements as calculated in get_tma_alignment(data_type)
def is_tma_epilogue(epilogue_schedule_type):
  """Return True when the given schedule is one of the TMA epilogue schedules.

  ScheduleAuto is counted as a TMA epilogue; every NoSmem* schedule and
  EpilogueTransposed are not.
  """
  tma_schedules = (
    EpilogueScheduleType.ScheduleAuto,
    EpilogueScheduleType.TmaWarpSpecialized,
    EpilogueScheduleType.TmaWarpSpecializedCooperative,
    EpilogueScheduleType.TmaWarpSpecialized1Sm,
    EpilogueScheduleType.TmaWarpSpecialized2Sm,
    EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm,
    EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
    EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
    EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong,
  )
  return epilogue_schedule_type in tma_schedules
def to_grouped_schedule(schedule, grouped):
  """Translate a schedule into its grouped (PtrArray*) counterpart.

  When `grouped` is falsy the schedule is handed back untouched. Otherwise
  the schedule is looked up in a fixed table covering both kernel and
  epilogue schedules; a schedule missing from the table raises KeyError.
  """
  if not grouped:
    return schedule

  K = KernelScheduleType
  E = EpilogueScheduleType

  grouped_counterpart = {
    # SM90
    K.TmaWarpSpecializedCooperative: K.PtrArrayTmaWarpSpecializedCooperative,
    K.BlockwiseTmaWarpSpecializedCooperative: K.PtrArrayBlockwiseTmaWarpSpecializedCooperative,
    K.BlockwiseTmaWarpSpecializedPingpong: K.PtrArrayBlockwiseTmaWarpSpecializedPingpong,
    K.TmaWarpSpecializedPingpong: K.PtrArrayTmaWarpSpecializedPingpong,
    K.TmaWarpSpecializedCooperativeFP8FastAccum: K.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
    K.TmaWarpSpecializedPingpongFP8FastAccum: K.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum,
    # The plain TMA epilogue is paired with the grouped Pingpong epilogue.
    E.TmaWarpSpecialized: E.PtrArrayTmaWarpSpecializedPingpong,
    E.TmaWarpSpecializedCooperative: E.PtrArrayTmaWarpSpecializedCooperative,
    E.NoSmemWarpSpecialized: E.PtrArrayNoSmemWarpSpecialized,
    # SM100
    K.TmaWarpSpecialized1SmSm100: K.PtrArrayTmaWarpSpecialized1SmSm100,
    K.TmaWarpSpecialized2SmSm100: K.PtrArrayTmaWarpSpecialized2SmSm100,
    K.Nvf4TmaWarpSpecialized1SmSm100: K.PtrArrayNvf4TmaWarpSpecialized1SmSm100,
    K.Nvf4TmaWarpSpecialized2SmSm100: K.PtrArrayNvf4TmaWarpSpecialized2SmSm100,
    K.Mxf4TmaWarpSpecialized1SmSm100: K.PtrArrayMxf4TmaWarpSpecialized1SmSm100,
    K.Mxf4TmaWarpSpecialized2SmSm100: K.PtrArrayMxf4TmaWarpSpecialized2SmSm100,
    K.Mxf8f6f4TmaWarpSpecialized1SmSm100: K.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100,
    K.Mxf8f6f4TmaWarpSpecialized2SmSm100: K.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100,
    K.BlockwiseTmaWarpSpecialized1SmSm100: K.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100,
    K.BlockwiseTmaWarpSpecialized2SmSm100: K.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100,
    E.TmaWarpSpecialized1Sm: E.PtrArrayTmaWarpSpecialized1Sm,
    E.TmaWarpSpecialized2Sm: E.PtrArrayTmaWarpSpecialized2Sm,
    E.NoSmemWarpSpecialized1Sm: E.PtrArrayNoSmemWarpSpecialized1Sm,
    E.NoSmemWarpSpecialized2Sm: E.PtrArrayNoSmemWarpSpecialized2Sm,
    E.BlockwiseNoSmemWarpSpecialized1Sm: E.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm,
    E.BlockwiseNoSmemWarpSpecialized2Sm: E.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm,
    # SM103
    K.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: K.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103,
    K.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: K.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103,
    K.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: K.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103,
    K.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: K.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103,
    K.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: K.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch,
    K.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: K.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch,
    K.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: K.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch,
    K.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: K.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch,
    K.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: K.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch,
    K.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: K.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch,
    K.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: K.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch,
    K.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: K.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch,
  }

  return grouped_counterpart[schedule]

class TileSchedulerType(enum.Enum):
  """Tile scheduler choices."""
  Default = enum_auto()
  Persistent = enum_auto()
  StreamK = enum_auto()

# Tag string for every TileSchedulerType member; Default maps to 'void'.
TileSchedulerTag = {
  TileSchedulerType.Default: 'void',
  TileSchedulerType.Persistent: 'cutlass::gemm::PersistentScheduler',
  TileSchedulerType.StreamK: 'cutlass::gemm::StreamKScheduler',
}

# Suffix fragment for every TileSchedulerType member; only StreamK has one.
TileSchedulerSuffixes = {
  TileSchedulerType.Default: '',
  TileSchedulerType.Persistent: '',
  TileSchedulerType.StreamK: '_stream_k',
}
###################################################################################################

# Side selector with two members, Left and Right.
class SideMode(enum.Enum):
  Left = enum_auto()
  Right = enum_auto()

# Tag string (a C++-qualified enumerator name) for each SideMode.
SideModeTag = {
  SideMode.Left: 'cutlass::SideMode::kLeft',
  SideMode.Right: 'cutlass::SideMode::kRight'
}

# Two-letter abbreviation for each SideMode.
ShortSideModeNames = {
  SideMode.Left: 'ls',
  SideMode.Right: 'rs'
}

###################################################################################################

# Fill selector with two members, Lower and Upper.
class FillMode(enum.Enum):
  Lower = enum_auto()
  Upper = enum_auto()

# Tag string (a C++-qualified enumerator name) for each FillMode.
FillModeTag = {
  FillMode.Lower: 'cutlass::FillMode::kLower',
  FillMode.Upper: 'cutlass::FillMode::kUpper'
}

# One-letter abbreviation for each FillMode.
ShortFillModeNames = {
  FillMode.Lower: 'l',
  FillMode.Upper: 'u'
}

###################################################################################################

# Diagonal selector with two members, NonUnit and Unit.
class DiagType(enum.Enum):
  NonUnit = enum_auto()
  Unit = enum_auto()

# Tag string (a C++-qualified enumerator name) for each DiagType.
DiagTypeTag = {
  DiagType.NonUnit: 'cutlass::DiagType::kNonUnit',
  DiagType.Unit: 'cutlass::DiagType::kUnit'
}

# Two-letter abbreviation for each DiagType.
ShortDiagTypeNames = {
  DiagType.NonUnit: 'nu',
  DiagType.Unit: 'un'
}

###################################################################################################

# Opcode classes known to the generator.
class OpcodeClass(enum.Enum):
  Simt = enum_auto()
  TensorOp = enum_auto()
  WmmaTensorOp = enum_auto()
  SparseTensorOp = enum_auto()
  BlockScaledTensorOp = enum_auto()


# Lower-case short name for each OpcodeClass.
OpcodeClassNames = {
  OpcodeClass.Simt: 'simt',
  OpcodeClass.TensorOp: 'tensorop',
  OpcodeClass.WmmaTensorOp: 'wmma_tensorop',
  OpcodeClass.SparseTensorOp: 'sptensorop',
  OpcodeClass.BlockScaledTensorOp: 'bstensorop'
}

# Tag string (a C++-qualified type name) for each OpcodeClass.
OpcodeClassTag = {
  OpcodeClass.Simt: 'cutlass::arch::OpClassSimt',
  OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp',
  OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp',
  OpcodeClass.SparseTensorOp: 'cutlass::arch::OpClassSparseTensorOp',
  OpcodeClass.BlockScaledTensorOp: 'cutlass::arch::OpClassBlockScaledTensorOp'
}
###################################################################################################

# Kinds of operation the generator distinguishes.
class OperationKind(enum.Enum):
  Gemm = enum_auto()
  RankK = enum_auto()
  Rank2K = enum_auto()
  Trmm = enum_auto()
  Symm = enum_auto()
  Conv2d = enum_auto()
  Conv3d = enum_auto()

# Lower-case name for each OperationKind.
OperationKindNames = {
  OperationKind.Gemm: 'gemm',
  OperationKind.RankK: 'rank_k',
  OperationKind.Rank2K: 'rank_2k',
  OperationKind.Trmm: 'trmm',
  OperationKind.Symm: 'symm',
  OperationKind.Conv2d: 'conv2d',
  OperationKind.Conv3d: 'conv3d',
}

# Generation targets; only `library` is defined.
class Target(enum.Enum):
  library = enum_auto()

# Architecture name keyed by an integer code (50 ... 90).
ArchitectureNames = {
  50: 'maxwell',
  60: 'pascal',
  61: 'pascal',
  70: 'volta',
  75: 'turing',
  80: 'ampere',
  89: 'ada',
  90: 'hopper',
}

# Shared-memory capacity in KB, keyed by compute capability.
SharedMemPerCC = {
  70:  96, #  96KB of SMEM
  72:  96, #  96KB of SMEM
  75:  64, #  64KB of SMEM
  80: 163, # 163KB of SMEM - 1KB reserved for the driver
  86:  99, #  99KB of SMEM - 1KB reserved for the driver
  87: 163, # 163KB of SMEM - 1KB reserved for the driver
  89:  99, #  99KB of SMEM - 1KB reserved for the driver
  90: 227, # 227KB of SMEM - 1KB reserved for the driver
  100: 227, # 227KB of SMEM - 1KB reserved for the driver
}

###################################################################################################

#
def SubstituteTemplate(template, values):
  """Expand every `${key}` placeholder in `template` using `values`.

  Passes over `values` are repeated until one full pass leaves the text
  unchanged, so a substituted value may itself contain placeholders that get
  expanded on a later pass.

  Keys and values are treated as literal text. The previous implementation
  spliced the key into a regular expression unescaped and handed the value to
  re.sub as a replacement template, so a key containing regex metacharacters
  could match unrelated placeholders and a value containing a backslash was
  either rewritten or rejected with re.error.

  Args:
    template: text containing zero or more `${key}` placeholders.
    values: mapping from placeholder name to its replacement string.

  Returns:
    The fully expanded text.

  NOTE: a value that re-introduces its own placeholder (directly or through a
  cycle) never reaches a fixed point, exactly as before.
  """
  text = template
  changed = True
  while changed:
    changed = False
    for key, value in values.items():
      newtext = text.replace("${%s}" % key, value)
      if newtext != text:
        changed = True
      text = newtext
  return text

###################################################################################################

# Flavours of GEMM the generator can emit.
class GemmKind(enum.Enum):
  Gemm = enum_auto()
  Sparse = enum_auto()
  Universal = enum_auto()
  Universal3x = enum_auto()
  SparseUniversal3x = enum_auto()
  PlanarComplex = enum_auto()
  PlanarComplexArray = enum_auto()
  Grouped = enum_auto()
  BlockScaledUniversal3x = enum_auto()
  GroupedUniversal3x = enum_auto()
  GroupedBlockScaledUniversal3x = enum_auto()
  BlockwiseUniversal3x = enum_auto()
  GroupedBlockwiseUniversal3x = enum_auto()
# Name string for each GemmKind. Several kinds deliberately share a name:
# every non-grouped dense kind maps to "gemm", both sparse kinds to "spgemm",
# and every grouped kind to "gemm_grouped".
GemmKindNames = {
  GemmKind.Gemm: "gemm",
  GemmKind.Sparse: "spgemm",
  GemmKind.Universal: "gemm",
  GemmKind.Universal3x: "gemm",
  GemmKind.SparseUniversal3x: "spgemm",
  GemmKind.PlanarComplex: "gemm_planar_complex",
  GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
  GemmKind.Grouped: "gemm_grouped",
  GemmKind.BlockScaledUniversal3x: "gemm",
  GemmKind.GroupedUniversal3x: "gemm_grouped",
  GemmKind.GroupedBlockScaledUniversal3x: "gemm_grouped",
  GemmKind.BlockwiseUniversal3x: "gemm",
  GemmKind.GroupedBlockwiseUniversal3x: "gemm_grouped"
}

# Rank-K kinds; only Universal is defined.
class RankKKind(enum.Enum):
  Universal = enum_auto()

# Name string for each RankKKind.
RankKKindNames = {
  RankKKind.Universal: "rank_k"
}

# TRMM kinds; only Universal is defined.
class TrmmKind(enum.Enum):
  Universal = enum_auto()

# Name string for each TrmmKind.
TrmmKindNames = {
  TrmmKind.Universal: "trmm"
}

# SYMM kinds; only Universal is defined.
class SymmKind(enum.Enum):
  Universal = enum_auto()

# Name string for each SymmKind.
SymmKindNames = {
  SymmKind.Universal: "symm"
}

# Thread-level epilogue functors.
class EpilogueFunctor(enum.Enum):
  LinearCombination = enum_auto()
  LinearCombinationClamp = enum_auto()

# Tag string (a C++-qualified name) for each EpilogueFunctor.
EpilogueFunctorTag = {
  EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination',
  EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp',
}

# Mixed-input modes. No tag or name table accompanies this enum in the
# portion of the file shown here.
class MixedInputMode(enum.Enum):
  ConvertOnly = enum_auto()
  ScaleOnly = enum_auto()
  ScaleWithZeroPoint = enum_auto()

# Threadblock swizzling functors.
class SwizzlingFunctor(enum.Enum):
  Identity1 = enum_auto()
  Identity2 = enum_auto()
  Identity4 = enum_auto()
  Identity8 = enum_auto()
  Horizontal = enum_auto()
  StridedDgradIdentity1 = enum_auto()
  StridedDgradIdentity4 = enum_auto()
  StridedDgradHorizontal = enum_auto()
  StreamK = enum_auto()

# Tag string (a C++-qualified type name, including template arguments where
# present) for each SwizzlingFunctor.
SwizzlingFunctorTag = {
  SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>',
  SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>',
  SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>',
  SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>',
  SwizzlingFunctor.Horizontal: 'cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle',
  SwizzlingFunctor.StridedDgradIdentity1: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>',
  SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>',
  SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle',
  SwizzlingFunctor.StreamK: 'cutlass::gemm::threadblock::ThreadblockSwizzleStreamK',
}
# Group schedule modes.
class GroupScheduleMode(enum.Enum):
  # A stray trailing comma after the first member used to turn its value into
  # a one-element tuple; both members now carry plain enum_auto() values.
  Device = enum_auto()
  Host = enum_auto()

# Tag string (a C++-qualified enumerator name) for each GroupScheduleMode.
GroupScheduleModeTag = {
  GroupScheduleMode.Device: 'cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly',
  GroupScheduleMode.Host: 'cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute'
}

# Short name for each GroupScheduleMode.
ShortGroupScheduleModeNames = {
  GroupScheduleMode.Device: 'Device',
  GroupScheduleMode.Host: 'Host'
}

###################################################################################################

# Convolution kinds, with explicit integer values.
class ConvKind(enum.IntEnum):
  Fprop = 0
  Dgrad = 1
  Wgrad = 2

# Tag string (a C++-qualified enumerator name) for each ConvKind.
ConvKindTag = {
  ConvKind.Fprop: 'cutlass::conv::Operator::kFprop',
  ConvKind.Dgrad: 'cutlass::conv::Operator::kDgrad',
  ConvKind.Wgrad: 'cutlass::conv::Operator::kWgrad'
}

# Lower-case name for each ConvKind.
ConvKindNames = {
  ConvKind.Fprop: 'fprop',
  ConvKind.Dgrad: 'dgrad',
  ConvKind.Wgrad: 'wgrad',
}

# Convolution modes, with explicit integer values.
class ConvMode(enum.IntEnum):
  CrossCorrelation = 0
  Convolution = 1

# Convolution iterator algorithms, with explicit integer values.
class IteratorAlgorithm(enum.Enum):
  Analytic = 0
  Optimized = 1
  FixedChannels = 2
  FewChannels = 3
  FixedStrideDilation = 4

# Tag string (a C++-qualified enumerator name) for each IteratorAlgorithm.
IteratorAlgorithmTag = {
  IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic',
  IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized',
  IteratorAlgorithm.FixedChannels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels',
  IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels',
  IteratorAlgorithm.FixedStrideDilation: 'cutlass::conv::IteratorAlgorithm::kFixedStrideDilation'
}
# Lower-case name for each IteratorAlgorithm.
IteratorAlgorithmNames = {
  IteratorAlgorithm.Analytic: 'analytic',
  IteratorAlgorithm.Optimized: 'optimized',
  IteratorAlgorithm.FixedChannels: 'fixed_channels',
  IteratorAlgorithm.FewChannels: 'few_channels',
  IteratorAlgorithm.FixedStrideDilation: 'fixed_stride_dilation'
}

# Stride support levels, with explicit integer values.
class StrideSupport(enum.Enum):
  Strided = 0
  Unity = 1
  Fixed = 2

# Tag string (a C++-qualified enumerator name) for each StrideSupport.
StrideSupportTag = {
  StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided',
  StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity',
  StrideSupport.Fixed: 'cutlass::conv::StrideSupport::kFixed'
}

# Name string for each StrideSupport; Strided maps to the empty string.
StrideSupportNames = {
  StrideSupport.Strided: '',
  StrideSupport.Unity: 'unity_stride',
  StrideSupport.Fixed: 'fixed_stride'
}

# Convolution group modes.
class GroupMode(enum.Enum):
  NoneGroup = enum_auto()         # dense conv (G=1)
  SingleGroup = enum_auto()       # grouped convolution (single group per CTA)
  MultipleGroup = enum_auto()     # grouped convolution (multiple groups per CTA)
  Depthwise = enum_auto()         # depthwise convolution (C=K=G)

# Tag string (a C++-qualified enumerator name) for each GroupMode.
GroupModeTag = {
  GroupMode.NoneGroup: 'cutlass::conv::GroupMode::kNone',
  GroupMode.SingleGroup: 'cutlass::conv::GroupMode::kSingleGroup',
  GroupMode.MultipleGroup: 'cutlass::conv::GroupMode::kMultipleGroup',
  GroupMode.Depthwise: 'cutlass::conv::GroupMode::kDepthwise',
}

# Name string for each GroupMode; NoneGroup maps to the empty string.
GroupModeNames = {
  GroupMode.NoneGroup: '',
  GroupMode.SingleGroup: 'single_group',
  GroupMode.MultipleGroup: 'multiple_group',
  GroupMode.Depthwise: 'depthwise',
}

# Module-level constant [0, 0, 1].
# NOTE(review): presumably a sentinel cluster shape meaning "chosen at run
# time" - confirm against the code that compares against it.
DynamicClusterShape = [0, 0, 1]

###################################################################################################

# Plain record describing one math instruction; the constructor only stores
# its arguments as same-named attributes.
class MathInstruction:
  # math_operation defaults to MathOperation.multiply_add and
  # element_scale_factor defaults to None; every other argument is required.
  def __init__(self,
      instruction_shape,                                            \
      element_a, element_b, element_accumulator,                    \
      opcode_class, math_operation = MathOperation.multiply_add     \
      , element_scale_factor = None
    ):

    self.instruction_shape = instruction_shape
    self.element_a = element_a
    self.element_b = element_b
    self.element_accumulator = element_accumulator
    self.opcode_class = opcode_class
    self.math_operation = math_operation
    self.element_scale_factor = element_scale_factor
#
class TileDescription:
  """Record describing one threadblock tile configuration.

  The constructor stores its arguments; `tile_shape` is kept as an alias of
  `threadblock_shape`, and `min_compute` / `max_compute` are stored as
  `minimum_compute_capability` / `maximum_compute_capability`.
  """

  def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, cluster_shape = None, explicit_vector_sizes = None):
    """Store the tile parameters.

    `cluster_shape` defaults to a fresh [1, 1, 1] list per instance. It used
    to be a mutable default argument, so every instance built without an
    explicit cluster shape shared (and could corrupt) the same list object.
    A caller-supplied list is stored as-is, without copying.
    """
    self.threadblock_shape = threadblock_shape
    self.tile_shape = threadblock_shape
    self.stages = stages
    self.warp_count = warp_count
    self.math_instruction = math_instruction
    self.minimum_compute_capability = min_compute
    self.maximum_compute_capability = max_compute
    self.cluster_shape = [1, 1, 1] if cluster_shape is None else cluster_shape
    self.explicit_vector_sizes = explicit_vector_sizes

  def procedural_name(self):
    """Return the tile's name fragment.

    For a minimum compute capability of 90 or above the form is
    "<M>x<N>x<K>_<cm>x<cn>x<ck>_<stages>", which includes the cluster shape;
    otherwise it is "<M>x<N>_<K>x<stages>".
    """
    if self.minimum_compute_capability >= 90:
      return "{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{s}".format(
        tbm = self.threadblock_shape[0],
        tbn = self.threadblock_shape[1],
        tbk = self.threadblock_shape[2],
        cm = self.cluster_shape[0],
        cn = self.cluster_shape[1],
        ck = self.cluster_shape[2],
        s = self.stages)
    else:
      return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)

# This class used to be defined twice, back to back and verbatim; the second
# definition simply rebound the name, so a single definition is kept.
class Direct2dConvFixedStrideDilationTileDescription:
  """Tile description for direct 2D convolution with fixed stride/dilation."""

  def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute):
    """Store the arguments and derive `threadblock_shape`.

    `threadblock_shape` is computed as
    [out[0] * out[1] * out[2], out[3], filter[0] * filter[1]] where `out` is
    `threadblock_output_shape` (four entries are read) and `filter` is
    `filter_shape` (two entries are read).
    """
    self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]]
    self.threadblock_output_shape = threadblock_output_shape
    self.filter_shape = filter_shape
    self.stages = stages
    self.warp_count = warp_count
    self.stride = stride
    self.dilation = dilation
    self.math_instruction = math_instruction
    self.minimum_compute_capability = min_compute
    self.maximum_compute_capability = max_compute

  def procedural_name(self):
    """Return the tile's name fragment.

    The base form encodes the derived threadblock shape, the output shape,
    the stage count and the filter shape. A "_stride..._dilation..." suffix
    is appended only when neither `stride` nor `dilation` equals [-1, -1].
    """
    str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0],
                                      self.threadblock_shape[1],
                                      self.threadblock_shape[2],
                                      self.threadblock_output_shape[0],
                                      self.threadblock_output_shape[1],
                                      self.threadblock_output_shape[2],
                                      self.threadblock_output_shape[3],
                                      self.stages,
                                      self.filter_shape[0],
                                      self.filter_shape[1])
    # Fixed stride and dilation
    if self.stride != [-1, -1] and self.dilation != [-1, -1]:
      str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0],
                                                  self.stride[1],
                                                  self.dilation[0],
                                                  self.dilation[1])
    return str_name
#
class TensorDescription:
  """Attribute bundle for a tensor operand: element, layout, alignment and complex transform."""

  def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none):
    self.element = element
    self.layout = layout
    self.alignment = alignment
    self.complex_transform = complex_transform

#
class SymmetricTensorDescription:
  """Attribute bundle like TensorDescription, plus a fill mode and a side mode."""

  def __init__(self, element, layout, fill_mode, alignment = 1, complex_transform = ComplexTransform.none, side_mode = SideMode.Left):
    self.element = element
    self.layout = layout
    self.fill_mode = fill_mode
    self.alignment = alignment
    self.complex_transform = complex_transform
    self.side_mode = side_mode

#
class TriangularTensorDescription:
  """Attribute bundle like TensorDescription, plus side mode, fill mode and diagonal type."""

  def __init__(self, element, layout, side_mode, fill_mode, diag_type, alignment = 1, complex_transform = ComplexTransform.none):
    self.element = element
    self.layout = layout
    self.side_mode = side_mode
    self.fill_mode = fill_mode
    self.diag_type = diag_type
    self.alignment = alignment
    self.complex_transform = complex_transform

#
def CalculateSmemUsage(operation):
  """Estimate an operation's shared-memory footprint, in whole KB.

  The per-stage byte count is derived from the threadblock shape and the
  operand element widths (DataTypeSize is divided by 8 here, i.e. treated as
  a width in bits), multiplied by the stage count and shifted right by 10.
  """
  tile = operation.tile_description
  tile_m = tile.threadblock_shape[0]
  tile_n = tile.threadblock_shape[1]
  tile_k = tile.threadblock_shape[2]
  stage_count = tile.stages

  is_sparse_gemm = (operation.operation_kind == OperationKind.Gemm
                    and operation.gemm_kind == GemmKind.Sparse)

  bits_a = DataTypeSize[operation.A.element]

  if is_sparse_gemm:
    # Elements represented by 8 bits of metadata (based on 4:8, 2:4 or 1:2 sparsity)
    elements_per_8b_md = {32: 2, 4: 8}.get(bits_a, 4)

    # Operand A is stored at half of K; the metadata term is added on top.
    bytes_a = bits_a * tile_m * (tile_k // 2) // 8
    bytes_b = DataTypeSize[operation.B.element] * tile_n * tile_k // 8
    bytes_metadata = tile_m * (tile_k // 2) // elements_per_8b_md
    bytes_per_stage = bytes_a + bytes_b + bytes_metadata
  else:
    # Few BLAS3 operations only have A tensor, so B falls back to A's width
    # unless the operation reports mixed input.
    bits_b = DataTypeSize[operation.B.element] if operation.is_mixed_input() else bits_a

    bytes_a = bits_a * tile_m * tile_k // 8
    bytes_b = bits_b * tile_n * tile_k // 8
    bytes_per_stage = bytes_a + bytes_b

  return (bytes_per_stage * stage_count) >> 10
DataTypeSize[operation.A.element] + if operation.is_mixed_input(): + data_type_size_b = DataTypeSize[operation.B.element] + + smem_per_stage = data_type_size_a * cta_shape[0] * cta_shape[2] // 8 + \ + data_type_size_b * cta_shape[1] * cta_shape[2] // 8 + + smem_usage = smem_per_stage * stages + return (smem_usage >> 10) + + +class GemmUniversalMode(enum.IntEnum): + """ + Types corresponding to GemmUniversalMode + """ + Gemm = 0 + GemmSplitKParallel = 1 + Batched = 2 + Array = 3 + + +class SplitKMode(enum.IntEnum): + """ + Types corresponding to SplitKMode + """ + NoneSplitK = 0 + Serial = 1 + Parallel = 2 diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/manifest.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/manifest.py new file mode 100644 index 0000000000000000000000000000000000000000..5733ef26322794ee650dfa0c8c2b170bd8c6f3e5 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/manifest.py @@ -0,0 +1,868 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for filtering CUTLASS library kernels and emitting library initialization +and building code +""" + +import enum +import logging +import os.path +import shutil + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * + from cutlass_library.gemm_operation import * + from cutlass_library.rank_k_operation import * + from cutlass_library.rank_2k_operation import * + from cutlass_library.trmm_operation import * + from cutlass_library.symm_operation import * + from cutlass_library.conv2d_operation import * + from cutlass_library.conv3d_operation import * +except ImportError: + from library import * + from gemm_operation import * + from rank_k_operation import * + from rank_2k_operation import * + from trmm_operation import * +
from symm_operation import * + from conv2d_operation import * + from conv3d_operation import * + +################################################################################################### +_LOGGER = logging.getLogger(__name__) + + +class EmitOperationKindAll: + """ + Emit the OperationKind-level CUTLASS library initialization code. + The code is generated in the {generated_path}/{operation_kind} directory + (e.g., tools/library/generated/gemm in the build directory, + for OperationKind=Gemm), in the all_{operation_kind}_operations.cu file + (e.g., all_gemm_operations.cu for OperationKind=Gemm). + That file declares several functions in namespace cutlass::library. + The functions all have this form, + + void initialize_{configuration_name}(Manifest& manifest); + + The file also _defines_ the following function in that namespace. + + void initialize_all_{operation_kind}_operations(Manifest& manifest); + + That function calls all of the functions declared in this file. + Those functions are defined in subdirectories + (which this class does not create). + """ + + def __init__(self, generated_path, kind, args): + self.generated_path = generated_path + self.kind = kind + self.args = args + + self.header_template =""" +/* + Generated by manifest.py - Do not edit. 
+*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.entry_template = """ + +// +// Entry point to construct operations +// +void initialize_all_${operation_name}_operations(Manifest &manifest) { +""" + self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n" + self.configuration_template =" initialize_${configuration_name}(manifest);\n" + + self.epilogue_template ="""} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +""" + + # + def __enter__(self): + _LOGGER.debug("*** EmitOperationKindAll::__enter__") + + self.operation_path = os.path.join(self.generated_path, OperationKindNames[self.kind]) + _LOGGER.debug('*** operation_path (directory to create): ' + + str(self.operation_path)); + os.makedirs(self.operation_path, exist_ok=True) + + self.top_level_path = os.path.join(self.operation_path, f"all_{OperationKindNames[self.kind]}_operations.cu") + _LOGGER.debug(f"*** top_level_path (file to write): {str(self.top_level_path)}") + + self.top_level_file = open(self.top_level_path, "w") + self.top_level_file.write(self.header_template) + + self.source_files = [self.top_level_path,] + + self.configurations = [] + + return self + + # + def emit(self, operations): + _LOGGER.debug('*** EmitOperationKindAll::emit') + _LOGGER.debug(f"*** len(operations): {len(operations)}") + _LOGGER.debug(f"*** min_cc list: {sorted(min_cc for min_cc, _ in operations.items())}") + + for min_cc, configurations in sorted(operations.items()): + _LOGGER.debug(f"*** min_cc={min_cc}") + + for configuration_name, _ in configurations.items(): + _LOGGER.debug(f"*** configuration_name={configuration_name}") + 
self.configurations.append(configuration_name) + self.top_level_file.write(SubstituteTemplate(self.configuration_prototype_template, {'configuration_name': configuration_name} )) + + # + def __exit__(self, exception_type, exception_value, traceback): + _LOGGER.debug("*** EmitOperationKindAll::__exit__") + + self.top_level_file.write(SubstituteTemplate(self.entry_template, {'operation_name': OperationKindNames[self.kind]})) + + for configuration_name in self.configurations: + self.top_level_file.write(SubstituteTemplate(self.configuration_template, {'configuration_name': configuration_name})) + + self.top_level_file.write(self.epilogue_template) + self.top_level_file.close() + + +class EmitOperationKindLibrary: + """ + Emit the CUTLASS library initialization code for each OperationKind. + The code is generated in the directory + {generated_path}/{operation_kind}/{min_cc} + (e.g., tools/library/generated/gemm/90 in the build directory, + for min_cc=90 and OperationKind=Gemm), in the file + all_sm{min_cc}_{operation_kind}_operations.cu + (e.g., all_sm90_gemm_operations.cu for min_cc=90 and OperationKind=Gemm). + The min_cc variable here indicates the minimum GPU architecture version + that the things to be initialized require. + For example, min_cc=90 indicates sm90. + + That file declares several functions in namespace cutlass::library. + The functions all have this form, + + void initialize_all_sm{min_cc}_{subclass_name}_{extended_name}_operations(Manifest& manifest); + + where extended_name is operation.extended_name() for all the operations + given to the emit method (which see below). (All operations for a given + configuration_name are guaranteed to have the same extended_name().) + + The file also _defines_ the following function in that namespace. + + void initialize_all_sm{min_cc}__{operation_kind}_operations(Manifest& manifest); + + That function calls all of the functions declared in this file. + Those functions are defined in subdirectories. 
+ The mapping from OperationKind to emitter handles the details + of what happens in each of those subdirectories. + """ + + def __init__(self, generated_path, min_cc, kind, args): + self.generated_path = generated_path + self.min_cc = min_cc + self.kind = kind + self.args = args + self.emitters = { + OperationKind.Gemm: EmitGemmConfigurationLibrary, + OperationKind.Conv2d: EmitConv2dConfigurationLibrary, + OperationKind.Conv3d: EmitConv3dConfigurationLibrary, + OperationKind.RankK: EmitRankKConfigurationLibrary, + OperationKind.Rank2K: EmitRank2KConfigurationLibrary, + OperationKind.Trmm: EmitTrmmConfigurationLibrary, + OperationKind.Symm: EmitSymmConfigurationLibrary + } + + self.header_template =""" +/* + Generated by manifest.py - Do not edit. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + self.entry_template = """ + +// +// Entry point to construct operations +// +void initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(Manifest &manifest) { +""" + self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n" + self.configuration_template = " initialize_${configuration_name}(manifest);\n" + self.subclass_call_template = " initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(manifest);\n" + self.subclass_prototype_template = "void initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(Manifest &manifest);\n" + self.epilogue_template ="""} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +""" + + # + def __enter__(self): + _LOGGER.debug("*** EmitOperationKindLibrary::__enter__") + _LOGGER.debug(f"*** generated_path: 
{str(self.generated_path)}") + _LOGGER.debug(f"*** OperationKindNames[kind]: {OperationKindNames[self.kind]}") + _LOGGER.debug(f"*** min_cc: {self.min_cc}") + + self.operation_path = os.path.join(self.generated_path, OperationKindNames[self.kind], str(self.min_cc)) + _LOGGER.debug(f"*** operation_path (directory to make): {str(self.operation_path)}") + os.makedirs(self.operation_path) + + self.top_level_path = os.path.join(self.operation_path, f"all_sm{self.min_cc}_{OperationKindNames[self.kind]}_operations.cu") + _LOGGER.debug(f"*** top_level_path (file to write): {str(self.top_level_path)}") + + self.top_level_file = open(self.top_level_path, "w") + self.top_level_file.write(self.header_template) + + self.source_files = {} + + # Each {operation_kind x cc} combination is further decomposed by the instruction + # types used. This dictionary used to track the file handles for the top-level + # files of each subclass + self.subclass_files = {} + + # Configurations in each sub class + self.subclass_configurations = {} + + return self + + # + def emit(self, configuration_name, operations): + _LOGGER.debug("*** EmitOperationKindLibrary::emit") + _LOGGER.debug(f"*** configuration_name: {configuration_name}") + + assert len(operations) > 0 + + # The extended name for all operations of a given configuration_name is guaranteed + # to be the same because extended_name() is used in defining configuration_name. Thus, + # we can safely use the extended_name() of the first operation. 
+ extended_name = operations[0].extended_name() + _LOGGER.debug('*** extended_name (for all ops): ' + extended_name) + + # Create a directory for operations with this subclass if it does not exist + if extended_name not in self.subclass_files: + subclass_path = os.path.join(self.operation_path, extended_name) + _LOGGER.debug(f"*** subclass_path: {str(subclass_path)}") + os.mkdir(subclass_path) + + self.subclass_configurations[extended_name] = [] + + # Open a new top-level file for this sub class + subclass_top_level_path = os.path.join( + subclass_path, f"all_sm{self.min_cc}_{extended_name}_{OperationKindNames[self.kind]}_operations.cu") + _LOGGER.debug('*** subclass_top_level_path (min_cc, extended_name, ' + + 'OperationKind): ' + str(subclass_top_level_path)) + + self.subclass_files[extended_name] = open(subclass_top_level_path, "w") + self.subclass_files[extended_name].write(self.header_template) + + self.source_files[extended_name] = [subclass_top_level_path] + + subclass_dir = os.path.dirname(self.subclass_files[extended_name].name) + _LOGGER.debug('*** subclass_dir: ' + str(subclass_dir)) + + with self.emitters[self.kind](subclass_dir, configuration_name) as configuration_emitter: + for operation in operations: + configuration_emitter.emit(operation) + + _LOGGER.debug('*** configuration_emitter.configuration_path: ' + + str(configuration_emitter.configuration_path)) + self.source_files[extended_name].append(configuration_emitter.configuration_path) + + self.subclass_configurations[extended_name].append(configuration_name) + self.subclass_files[extended_name].write(SubstituteTemplate(self.configuration_prototype_template, {'configuration_name': configuration_name} )) + + # + def __exit__(self, exception_type, exception_value, traceback): + _LOGGER.debug("*** EmitOperationKindLibrary::__exit__") + for subclass_name, subclass_file in sorted(self.subclass_files.items()): + subclass_cfg = { + 'min_cc': str(self.min_cc), + 'subclass_name': subclass_name, + 
'operation_name': OperationKindNames[self.kind] + } + self.top_level_file.write(SubstituteTemplate(self.subclass_prototype_template, subclass_cfg)) + + self.top_level_file.write( + SubstituteTemplate(self.entry_template, { + 'min_cc': str(self.min_cc), + 'subclass_name': '', + 'operation_name': OperationKindNames[self.kind] + })) + + # Finish and close all subclass files + for subclass_name, subclass_file in sorted(self.subclass_files.items()): + subclass_cfg = { + 'min_cc': str(self.min_cc), + 'subclass_name': subclass_name, + 'operation_name': OperationKindNames[self.kind] + } + subclass_file.write(SubstituteTemplate(self.entry_template, subclass_cfg)) + + for configuration in self.subclass_configurations[subclass_name]: + subclass_file.write( + SubstituteTemplate(self.configuration_template, { + 'configuration_name': configuration + })) + + subclass_file.write(self.epilogue_template) + subclass_file.close() + + # Write the call to initialize_all for this subclass to the top-level file + self.top_level_file.write(SubstituteTemplate(self.subclass_call_template, subclass_cfg)) + + self.top_level_file.write(self.epilogue_template) + self.top_level_file.close() + +class EmitInterfaceLibrary: + """ + Emit the topmost-level CUTLASS library initialization code. + The code is generated in the generated_path directory + (e.g., tools/library/generated in the build directory), + in the initialize_all.cpp file. + That file declares several functions in namespace cutlass::library. + The functions all have this form, + + void initialize_all_{operation_kind}_operations(Manifest& manifest); + + where {operation_kind} abbreviates the "kind" of operation + (e.g., gemm for matrix-matrix multiply, conv2d for 2-d convolution, + or trmm for triangular solve with multiple right-hand sides). + The definitions of these functions live in subdirectories. + + The file also _defines_ the following function in that namespace. 
+ + void initialize_all(Manifest& manifest); + + That function first prepares the manifest, and then + calls all of the functions declared in this file. + """ + + def __init__(self, generated_path, operation_count, args): + self.generated_path = generated_path + self.args = args + + self.prototypes = [] + self.fn_calls = [] + self.operation_count = str(operation_count) + + self.top_level_hdr_template = ''' +/* + Generated by manifest.py - Do not edit. +*/ +''' + self.top_level_prologue = ''' + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +namespace cutlass { +\tnamespace library { + +${prototypes} +''' + + self.top_level_initialize_kind = ''' +\t\tvoid initialize_all_${kind}_operations(Manifest &manifest) { +${fn_calls} +\t\t} +''' + + self.top_level_initialize = ''' +\t\tvoid initialize_all(Manifest &manifest) { +\t\t\tmanifest.reserve(${operation_count});\n +${fn_calls} +\t\t} +''' + + self.top_level_suffix = ''' +\t} // namespace library +} // namespace cutlass + +''' + + # + def __enter__(self): + _LOGGER.debug("*** EmitInterfaceLibrary::__enter__") + + self.top_level_path = os.path.join(self.generated_path, 'initialize_all.cpp') + _LOGGER.debug("*** top_level_path: " + str(self.top_level_path)) + + self.top_level_file = open(self.top_level_path, "w") + self.top_level_file.write(self.top_level_hdr_template) + + self.source_files = [self.top_level_path,] + + return self + + # + def emit(self, operation_name): + _LOGGER.debug("*** EmitInterfaceLibrary::emit") + _LOGGER.debug("*** operation_name: " + operation_name) + + self.prototypes.append(SubstituteTemplate( + "\t\tvoid initialize_all_${operation_kind}_operations(Manifest &manifest);", + {'operation_kind': operation_name})) + + self.fn_calls.append(SubstituteTemplate( + "\t\t\tinitialize_all_${operation_kind}_operations(manifest);", + {'operation_kind': operation_name})) + + # + def __exit__(self, exception_type, exception_value, traceback): + _LOGGER.debug("*** 
EmitInterfaceLibrary::__exit__") + + self.top_level_file.write(SubstituteTemplate(self.top_level_prologue, {'prototypes':"\n".join(self.prototypes)})) + + # Write out initialize_all method + self.top_level_file.write(SubstituteTemplate(self.top_level_initialize, + {'operation_count': self.operation_count, 'fn_calls':"\n".join(self.fn_calls)})) + + self.top_level_file.write(self.top_level_suffix) + self.top_level_file.close() + +################################################################################################### +################################################################################################### + +class Options: + def __init__(self): + pass + +################################################################################################### + +# +class Manifest: + + # + def __init__(self, args = None): + self.operations = {} + self.args = args + self.operation_count = 0 + self.operations_by_name = {} + + self.kernel_filter = '' + self.kernel_filter_list = [] + self.kernel_names = [] + self.operations_enabled = [] + self.selected_kernels = [] + self.ignore_kernel_names = [] + self.exclude_kernel_names = [] + self.compute_capabilities_baseline = [50,] + self.compute_capabilities_feature_set = ['50',] + self.curr_build_dir = '.' + self.filter_by_cc = True + + if self.args: + self.kernel_filter = self.args.kernels + self.curr_build_dir = args.curr_build_dir + + # A common user error is to use commas instead of semicolons. 
+ if ',' in args.architectures: + raise RuntimeError("The list of architectures (CMake option CUTLASS_NVCC_ARCHS) must be semicolon-delimited.\nDon't use commas to separate the architectures; use semicolons.\nYou specified the list as: " + args.architectures) + + self.compute_capabilities_feature_set = args.architectures.split(';') if len(args.architectures) else ['50',] + self.compute_capabilities_baseline = sorted(set(int(arch.split('a')[0].split('f')[0]) for arch in self.compute_capabilities_feature_set)) + + if args.filter_by_cc in ['false', 'False', '0']: + self.filter_by_cc = False + + if args.operations == 'all': + self.operations_enabled = [] + else: + operations_list = [ + OperationKind.Gemm + , OperationKind.Conv2d + , OperationKind.Conv3d + , OperationKind.RankK + , OperationKind.Trmm + , OperationKind.Symm + ] + self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')] + + if args.kernels == 'all': + self.kernel_names = [] + else: + self.kernel_names = [x for x in args.kernels.split(',') if x != ''] + + self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != ''] + self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != ''] + + if args.kernel_filter_file is None: + self.kernel_filter_list = [] + else: + self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file) + _LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format( + filter_count = len(self.kernel_filter_list), + filter_file = args.kernel_filter_file)) + + self.operation_count = 0 + self.operations_by_name = {} + self.disable_full_archs_compilation = args.disable_full_archs_compilation + self.is_kernel_filter_set_to_all = args.instantiation_level == "max" and args.kernels != '' + self.instantiation_level = 0 + try: + self.instantiation_level = int(args.instantiation_level) + except ValueError: + self.instantiation_level = 0 + + def add_kernel_filter(self, 
filter_str): + filter_re = re.compile(filter_str) + + self.kernel_filter_list.append(filter_re) + + def get_instantiation_level(self, pruned_level=0, default_level=111, exhaustive_level=9992): + # Non-negative integer which determines how many kernels are instantiated. + # 0 = 0000 generates the fewest kernels, 9999 generates all possible combinations. + # increasing first digit reduces schedule / mixed type pruning, + # increasing second digit generates more cluster sizes, + # increasing third digit generates more MMA multipliers, + # increasing fourth digit generates more instruction shapes. + + if self.instantiation_level > 0: + return self.instantiation_level + + elif self.is_kernel_filter_set_to_all: + return exhaustive_level + + elif self.kernel_filter == '': + return pruned_level + + else: + return default_level + + + def get_kernel_filters(self, kernelListFile): + if os.path.isfile(kernelListFile): + with open(kernelListFile, 'r') as fileReader: + lines = [line.rstrip() for line in fileReader if not line.startswith("#")] + + lines = [re.compile(line) for line in lines if line] + return lines + else: + return [] + + # + def filter_out_kernels(self, kernel_name, kernel_filter_list): + + for kernel_filter_re in kernel_filter_list: + if kernel_filter_re.search(kernel_name) is not None: + return True + + return False + + + # + def _filter_string_matches(self, filter_string, haystack): + ''' Returns true if all substrings appear in the haystack in order''' + substrings = filter_string.split('*') + for sub in substrings: + idx = haystack.find(sub) + if idx < 0: + return False + haystack = haystack[idx + len(sub):] + return True + + # + def filter(self, operation): + ''' Filtering operations based on various criteria''' + + # filter based on compute capability + enabled = not (self.filter_by_cc) + + for cc in self.compute_capabilities_baseline: + + if cc >= operation.tile_description.minimum_compute_capability and \ + cc <= 
operation.tile_description.maximum_compute_capability and \ + (cc not in SharedMemPerCC or SharedMemPerCC[cc] >= CalculateSmemUsage(operation)): + + enabled = True + break + + if not enabled: + return False + + if len(self.operations_enabled) and not operation.operation_kind in self.operations_enabled: + return False + + name = operation.procedural_name() + + # eliminate duplicates + if name in self.operations_by_name.keys(): + return False + + # Filter based on list of valid substrings + if len(self.kernel_names): + enabled = False + + # compare against the include list + for name_substr in self.kernel_names: + if self._filter_string_matches(name_substr, name): + _LOGGER.debug(f"Kernel {name} included due to filter string '{name_substr}'.") + enabled = True + break + else: + _LOGGER.debug(f"Kernel {name} NOT included due to not matching '{name_substr}'.") + + # compare against the exclude list + for name_substr in self.ignore_kernel_names: + if self._filter_string_matches(name_substr, name): + _LOGGER.debug(f"Kernel {name} ignored due to filter string '{name_substr}'.") + enabled = False + break + else: + _LOGGER.debug(f"Kernel {name} NOT ignored due to not matching '{name_substr}'.") + + if len(self.kernel_filter_list) > 0: + if self.filter_out_kernels(name, self.kernel_filter_list): + _LOGGER.debug(f"Kernel {name} matched via kernel filter file.") + enabled = True + else: + _LOGGER.debug(f"Kernel {name} culled due to no match in kernel filter file.") + enabled = False + + # CUTLASS_LIBRARY_IGNORE_KERNELS ("ignore" list) only takes effect + # if CUTLASS_LIBRARY_KERNELS was specified. + # Changing that would break backwards compatibility. + # Thus, CUTLASS has introduced the new CMake option CUTLASS_LIBRARY_EXCLUDE_KERNELS, + # that always takes effect, whether or not CUTLASS_LIBRARY_KERNELS was specified. 
+ for name_substr in self.exclude_kernel_names: + if self._filter_string_matches(name_substr, name): + _LOGGER.debug(f"Kernel {name} excluded due to filter string '{name_substr}'.") + enabled = False + break + else: + _LOGGER.debug(f"Kernel {name} NOT excluded due to not matching '{name_substr}'.") + + # TODO: filter based on compute data type + return enabled + # + + # + def append(self, operation): + ''' + Inserts the operation. + + operation_kind -> configuration_name -> [] + ''' + + if self.filter(operation): + + self.selected_kernels.append(operation.procedural_name()) + + self.operations_by_name[operation.procedural_name()] = operation + + # add the configuration + configuration_name = operation.configuration_name() + + # Split operations by minimum CC + min_cc = operation.arch + + if operation.operation_kind not in self.operations.keys(): + self.operations[operation.operation_kind] = {} + + if min_cc not in self.operations[operation.operation_kind]: + self.operations[operation.operation_kind][min_cc] = {} + + if configuration_name not in self.operations[operation.operation_kind][min_cc].keys(): + self.operations[operation.operation_kind][min_cc][configuration_name] = [] + + self.operations[operation.operation_kind][min_cc][configuration_name].append(operation) + self.operation_count += 1 + else: + _LOGGER.debug("Culled {} from manifest".format(operation.procedural_name())) + # + + def emit_manifest_cmake(self, manifest_path, top_level_path, source_files): + with open(manifest_path, "w") as manifest_file: + + target_text = SubstituteTemplate("""cutlass_target_sources(cutlass_library_objs PRIVATE + """, { }) + manifest_file.write(target_text + '\n\n') + manifest_file.write(" %s\n" % str(top_level_path.replace('\\', '/'))) + generated_path = os.path.join(self.curr_build_dir, 'generated') + for kind in self.operations.keys(): + kind_str = OperationKindNames[kind] + all_kind_file = os.path.join(generated_path, kind_str, 
f"all_{kind_str}_operations.cu").replace('\\', '/') + manifest_file.write(f" {all_kind_file}\n") + manifest_file.write(')\n\n') + + for kind in self.operations.keys(): + for min_cc in sorted(self.operations[kind].keys()): + for subclass in sorted(source_files[kind][min_cc].keys()): + target_text = SubstituteTemplate("""cutlass_add_cutlass_library( + SUFFIX ${kind}_sm${min_cc}_${subclass} +""", { 'min_cc': str(min_cc), 'kind': OperationKindNames[kind], 'subclass': subclass }) + manifest_file.write(target_text + '\n\n') + + for source_file in source_files[kind][min_cc][subclass]: + manifest_file.write(" %s\n" % str(source_file.replace('\\', '/'))) + + manifest_file.write(")\n") + + if self.disable_full_archs_compilation: + self.emit_disable_full_archs_compilation(manifest_file, source_files) + + def emit_disable_full_archs_compilation(manifest_file, source_files): + def for_hopper(name): + pass + + def for_ampere(name): + return "16816" in name or \ + "16832" in name or \ + "16864" in name or \ + ("1688" in name and "tf32" in name) + + def for_turing(name): + return ("1688" in name and "tf32" not in name) or \ + "8816" in name + + def for_volta(name): + return "884" in name + + def is_cpp(name): + return name.endswith(".cpp") + + def get_src_archs_str_given_requested_cuda_archs(archs, source_file): + intersected_archs = archs & set(self.compute_capabilities_baseline) + if intersected_archs == set(): + raise RuntimeError( + """ + Empty archs set for file {} after taking + the intersection of {} (global requested archs) and + {} (per file requested archs) + """.format(source_file, set(self.compute_capabilities_baseline), archs)) + else: + return " ".join(map(str, intersected_archs)) + + for min_cc in sorted(source_files.keys()): + for source_file in source_files[min_cc]: + if is_cpp(source_file): + continue # skip because source is cpp + elif for_ampere(source_file): + archs_str = get_src_archs_str_given_requested_cuda_archs({80, 87, 90}, source_file) + elif 
for_turing(source_file): + archs_str = get_src_archs_str_given_requested_cuda_archs({75}, source_file) + elif for_volta(source_file): + archs_str = get_src_archs_str_given_requested_cuda_archs({70, 72}, source_file) + else: + raise RuntimeError("Per file archs are not set {}, as there is no rule specified for this file pattern".format(source_file)) + + manifest_file.write("cutlass_apply_cuda_gencode_flags({} SM_ARCHS {})\n".format(str(source_file.replace('\\', '/')), archs_str)) + + # + def emit(self, target = GeneratorTarget.Library): + + operation_emitters = { + GeneratorTarget.Library: EmitOperationKindLibrary + } + + # Emitters for all operations that fall under a particular kind (e.g., GEMM, Conv2d) + kind_emitters = { + GeneratorTarget.Library: EmitOperationKindAll + } + + interface_emitters = { + GeneratorTarget.Library: EmitInterfaceLibrary + } + + generated_path = os.path.join(self.curr_build_dir, 'generated') + + # create generated/ + if os.path.exists(generated_path): + shutil.rmtree(generated_path) + + os.mkdir(generated_path) + + with interface_emitters[target](generated_path, self.operation_count, self.args) as iface_emitter: + top_level_path = iface_emitter.top_level_path + for operation_kind in self.operations.keys(): + iface_emitter.emit(OperationKindNames[operation_kind]) + + source_files = {} + for kind in self.operations.keys(): + source_files[kind] = {} + for min_cc in self.operations[kind].keys(): + source_files[kind][min_cc] = {} + + for operation_kind, ops in self.operations.items(): + for min_cc, configurations in sorted(ops.items()): + with operation_emitters[target](generated_path, min_cc, operation_kind, self.args) as operation_kind_emitter: + for configuration_name, operations in configurations.items(): + _LOGGER.info(f"Emitting {configuration_name} with {len(operations)} operation{'' if len(operations) == 1 else 's'}.") + operation_kind_emitter.emit(configuration_name, operations) + + for subclass, files in 
operation_kind_emitter.source_files.items(): + if subclass not in source_files[operation_kind][min_cc]: + source_files[operation_kind][min_cc][subclass] = [] + source_files[operation_kind][min_cc][subclass].extend(operation_kind_emitter.source_files[subclass]) + + # Emit top level all_{gemm, conv2d, ...}_operations.cu files + with kind_emitters[target](generated_path, operation_kind, self.args) as operation_kind_emitter: + operation_kind_emitter.emit(ops) + + # write the manifest.cmake file containing paths from all targets + manifest_path = os.path.join(generated_path, "manifest.cmake") + + self.emit_manifest_cmake(manifest_path, top_level_path, source_files) + +################################################################################################### diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_2k_operation.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_2k_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..29ef056f26f914a9c3c33e13900c33642ad2f1b7 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_2k_operation.py @@ -0,0 +1,438 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#
#################################################################################################

"""
Utilities for emitting Rank2K kernels
"""

import enum
import functools
import operator
import os.path
import shutil

try:
    import builtins
    if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
        raise ImportError("Disabling attempt to import cutlass_library")
    from cutlass_library.library import *
except ImportError:
    from library import *


###################################################################################################
#
# Data structure modeling a Rank 2K update operation
#
###################################################################################################

#
class Rank2KOperation:
    """Description of one rank-2k update kernel ('syr2k' or 'her2k').

    Stores the operand descriptions and the tiling choice, and derives the
    procedural name under which the kernel is emitted.
    """

    #
    def __init__(self, rank_k_kind, arch, tile_description, A, C, element_epilogue, \
        epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \
        blas_mode = BlasMode.symmetric):
        """Record the operation description.

        Args:
            rank_k_kind: key selecting the instance emitter and wrapper
                (looked up in the RankKKind-keyed tables of the emitters).
            arch: compute capability number, formatted as "Sm%d" on emission.
            tile_description: tile configuration; this module reads its
                math_instruction, threadblock_shape, warp_count, stages and
                procedural_name().
            A: operand A description (element, layout, alignment,
                complex_transform are read). Operand B reuses the same object.
            C: operand C description (element, layout, alignment, fill_mode
                are read).
            element_epilogue: data type used by the epilogue functor.
            epilogue_functor: epilogue functor kind.
            swizzling_functor: threadblock swizzling functor kind.
            blas_mode: BlasMode.symmetric emits 'syr2k'; any other value
                emits 'her2k'.
        """
        self.blas_mode = blas_mode
        self.operation_kind = OperationKind.Rank2K
        self.arch = arch
        self.tile_description = tile_description
        self.rank_k_kind = rank_k_kind
        # tensor A and B have same data type and layout
        self.A = A
        self.B = A
        self.C = C
        self.element_epilogue = element_epilogue
        self.epilogue_functor = epilogue_functor
        self.swizzling_functor = swizzling_functor

    #
    def is_complex(self):
        """Return True when the math operation is one of the complex multiply-add variants."""
        complex_operators = (
            MathOperation.multiply_add_complex,
            MathOperation.multiply_add_complex_gaussian,
            MathOperation.multiply_add_complex_fast_f32,
        )
        return self.tile_description.math_instruction.math_operation in complex_operators

    #
    def is_mixed_input(self):
        """Return True when A and B have different element types."""
        return self.A.element != self.B.element

    #
    def is_planar_complex(self):
        """Planar-complex kernels are not emitted by this class; always False."""
        return False

    #
    def accumulator_type(self):
        """Return the accumulator data type, promoted to its complex counterpart for complex math."""
        accum = self.tile_description.math_instruction.element_accumulator

        if self.is_complex():
            return get_complex_from_real(accum)

        return accum

    #
    def short_math_name(self):
        """Return the short accumulator type name, prefixed with 'g' for Gaussian complex math."""
        short_name = ShortDataTypeNames[self.accumulator_type()]
        if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
            return "g%s" % short_name
        return short_name

    #
    def core_name(self):
        ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
        math_instruction = self.tile_description.math_instruction

        inst_shape = ''
        intermediate_type = ''

        math_operations_map = {
            MathOperation.xor_popc: 'xor',
            MathOperation.and_popc: 'and'
        }

        # Only tensor-op kernels carry the instruction shape (and an optional
        # intermediate type) in their name.
        if math_instruction.opcode_class == OpcodeClass.TensorOp or \
           math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:

            math_op_string = math_operations_map.get(math_instruction.math_operation, '')

            inst_shape = "%d%d%d" % tuple(math_instruction.instruction_shape)
            inst_shape += math_op_string

            if math_instruction.element_a != self.A.element and \
               math_instruction.element_a != math_instruction.element_accumulator:
                intermediate_type = DataTypeNames[math_instruction.element_a]

        operation_name = 'syr2k' if self.blas_mode == BlasMode.symmetric else 'her2k'

        return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name)

    #
    def extended_name(self):
        ''' Append data types if they differ from compute type. '''
        accumulator = self.tile_description.math_instruction.element_accumulator

        if self.is_complex():
            extended_name = "${core_name}"
        elif self.C.element != accumulator and self.A.element != accumulator:
            extended_name = "${element_c}_${core_name}_${element_a}"
        elif self.C.element == accumulator and self.A.element != accumulator:
            extended_name = "${core_name}_${element_a}"
        else:
            extended_name = "${core_name}"

        extended_name = SubstituteTemplate(extended_name, {
            'element_a': DataTypeNames[self.A.element],
            'element_c': DataTypeNames[self.C.element],
            'core_name': self.core_name()
        })

        return extended_name

    #
    def layout_name(self):
        """Return the short layout name of A; complex kernels also encode A's complex transform."""
        if self.is_complex() or self.is_planar_complex():
            return "%s" % (
                ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)]
            )
        return "%s" % (ShortLayoutTypeNames[self.A.layout])

    #
    def fill_mode_name(self):
        """Return the short name of C's fill mode."""
        return "%s" % (ShortFillModeNames[self.C.fill_mode])

    #
    def procedural_name(self):
        ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
        threadblock = self.tile_description.procedural_name()

        opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]

        # The name carries A's alignment. The original also computed
        # max(A.alignment, C.alignment) but never used it; using it here would
        # rename existing kernels, so it is dropped instead.
        return SubstituteTemplate(
            "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${fill_mode}_align${alignment}",
            {
                'opcode_class': opcode_class_name,
                'extended_name': self.extended_name(),
                'threadblock': threadblock,
                'layout': self.layout_name(),
                'fill_mode': self.fill_mode_name(),
                'alignment': "%d" % self.A.alignment,
            }
        )

    #
    def configuration_name(self):
        ''' The configuration name is identical to the full procedural name. '''
        return self.procedural_name()

###################################################################################################
#
# Emits single instances of a CUTLASS device-wide operator
#
###################################################################################################

#
class EmitRank2KUniversalInstance:
    ''' Responsible for emitting a CUTLASS template definition'''

    def __init__(self):
        # Template for real-valued kernels.
        self.rank_k_template = """
// Rank K operator ${operation_name}
using Operation_${operation_name} =
  typename cutlass::gemm::device::Rank2K<
    ${element_a}, ${layout_a},
    ${element_b}, ${layout_b},
    ${element_c}, ${layout_c}, ${fill_mode},
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor}<
      ${element_c},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_epilogue}
    >,
    ${swizzling_functor},
    ${stages},
    ${align_a},
    ${align_b},
    ${split_k_serial},
    ${math_operation}
>;
"""
        # Template for complex kernels: adds the complex transforms and the BLAS mode.
        self.rank_k_complex_template = """
// Rank K operator ${operation_name}
using Operation_${operation_name} =
  typename cutlass::gemm::device::Rank2K<
    ${element_a}, ${layout_a},
    ${element_b}, ${layout_b},
    ${element_c}, ${layout_c}, ${fill_mode},
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor}<
      ${element_c},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_epilogue}
    >,
    ${swizzling_functor},
    ${stages},
    ${align_a},
    ${align_b},
    ${split_k_serial},
    ${math_operation},
    ${transform_a},
    ${transform_b},
    ${blas_mode}
>;
"""

    def emit(self, operation):
        """Return the C++ `using` declaration for `operation` as a string."""
        tile = operation.tile_description
        math_instruction = tile.math_instruction
        threadblock_shape = tile.threadblock_shape

        warp_count = tile.warp_count
        warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]

        # The epilogue accesses at most 128 bits of C at a time.
        element_c_bits = DataTypeSize[operation.C.element]
        epilogue_vector_length = int(min(operation.C.alignment * element_c_bits, 128) / element_c_bits)

        values = {
            'operation_name': operation.procedural_name(),
            'element_a': DataTypeTag[operation.A.element],
            'layout_a': LayoutTag[operation.A.layout],
            'element_b': DataTypeTag[operation.B.element],
            'layout_b': LayoutTag[operation.B.layout],
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[operation.C.layout],
            'fill_mode': FillModeTag[operation.C.fill_mode],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[math_instruction.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(threadblock_shape[0]),
            'threadblock_shape_n': str(threadblock_shape[1]),
            'threadblock_shape_k': str(threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(math_instruction.instruction_shape[0]),
            'instruction_shape_n': str(math_instruction.instruction_shape[1]),
            'instruction_shape_k': str(math_instruction.instruction_shape[2]),
            'epilogue_vector_length': str(epilogue_vector_length),
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
            'stages': str(tile.stages),
            'align_a': str(operation.A.alignment),
            'align_b': str(operation.B.alignment),
            'split_k_serial': 'false',
            'math_operation': MathOperationTag[math_instruction.math_operation],
            'transform_a': ComplexTransformTag[operation.A.complex_transform],
            'transform_b': ComplexTransformTag[operation.B.complex_transform],
            'blas_mode': BlasModeTag[operation.blas_mode]
        }

        rank_k_template = self.rank_k_complex_template if operation.is_complex() else self.rank_k_template

        return SubstituteTemplate(rank_k_template, values)

###################################################################################################


###################################################################################################
#
# Emitters functions for all targets
#
###################################################################################################

class EmitRank2KConfigurationLibrary:
    """Context manager that writes one generated `<configuration_name>.cu` file.

    `__enter__` opens the file and writes the header, `emit` collects one
    instance definition and one manifest registration per operation, and
    `__exit__` writes everything collected and closes the file.
    """

    def __init__(self, operation_path, configuration_name):
        self.configuration_name = configuration_name
        self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')

        self.instance_emitter = {
            RankKKind.Universal: EmitRank2KUniversalInstance,
        }

        self.rank_k_kind_wrappers = {
            RankKKind.Universal: 'Rank2KOperation',
        }

        self.instance_template = {
            RankKKind.Universal: """
${compile_guard_start}
  manifest.append(new ${rank_k_kind}<
    Operation_${operation_name}
  >("${operation_name}"));
${compile_guard_end}
"""
        }

        # Opening preprocessor guard for WMMA kernels; emit() closes it with
        # "#endif". emit() read this attribute but it was never defined, so
        # every WmmaTensorOp operation raised AttributeError.
        self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"

        self.header_template = """
/*
  Generated by rank_2k_operation.py - Do not edit.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "library_internal.h"
#include "rank_2k_operation.h"

///////////////////////////////////////////////////////////////////////////////////////////////////

"""

        self.initialize_function_template = """

///////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace library {

///////////////////////////////////////////////////////////////////////////////////////////////////

void initialize_${configuration_name}(Manifest &manifest) {

"""
        self.epilogue_template = """

}

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////

"""

    def __enter__(self):
        """Open the output file, write the header and reset the collected state."""
        self.configuration_file = open(self.configuration_path, "w")
        self.configuration_file.write(self.header_template)

        self.instance_definitions = []
        self.instance_wrappers = []

        self.operations = []
        return self

    def emit(self, operation):
        """Collect the instance definition and manifest registration for `operation`."""
        emitter = self.instance_emitter[operation.rank_k_kind]()

        self.operations.append(operation)

        self.instance_definitions.append(emitter.emit(operation))

        # Only WMMA kernels are wrapped in a compile guard.
        is_wmma = operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp
        if is_wmma:
            compile_guard_start = SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)})
            compile_guard_end = "#endif"
        else:
            compile_guard_start = ""
            compile_guard_end = ""

        self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.rank_k_kind], {
            'configuration_name': self.configuration_name,
            'operation_name': operation.procedural_name(),
            'rank_k_kind': self.rank_k_kind_wrappers[operation.rank_k_kind],
            'compile_guard_start': compile_guard_start,
            'compile_guard_end': compile_guard_end
        }))

    def __exit__(self, exception_type, exception_value, traceback):
        """Write the collected definitions and registrations, then close the file.

        Exceptions raised inside the `with` body are not suppressed. The file
        is closed even if one of the writes fails.
        """
        try:
            # Write instance definitions in top-level namespace
            for instance_definition in self.instance_definitions:
                self.configuration_file.write(instance_definition)

            # Add wrapper objects within initialize() function
            self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
                'configuration_name': self.configuration_name
            }))

            for instance_wrapper in self.instance_wrappers:
                self.configuration_file.write(instance_wrapper)

            self.configuration_file.write(self.epilogue_template)
        finally:
            self.configuration_file.close()

###################################################################################################

# NOTE(review): the rest of this source line is not part of rank_2k_operation.py. It is the
# unified-diff header and the first license lines of the next file, kept here one original
# diff line per comment so that nothing is dropped.
# diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_k_operation.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_k_operation.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..9841952332a170d6f401dbe34a0093540c166fb8
# --- /dev/null
# +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_k_operation.py
# @@ -0,0 +1,427 @@
# +#################################################################################################
# +#
# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# +# SPDX-License-Identifier: BSD-3-Clause
# +#
# +# Redistribution and use in source and binary forms, with or without
# +# modification, are permitted provided that the following conditions are met:
# +#
# +# 1. Redistributions of source code must retain the above copyright notice, this
# +# list of conditions and the following disclaimer.
# +#
# +# 2.
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#
#################################################################################################

"""
Utilities for emitting RankK kernels
"""

import enum
import functools
import operator
import os.path
import shutil

try:
    import builtins
    if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
        raise ImportError("Disabling attempt to import cutlass_library")
    from cutlass_library.library import *
except ImportError:
    from library import *


###################################################################################################
#
# Data structure modeling a Rank K update operation
#
###################################################################################################

#
class RankKOperation:
    """Description of one rank-k update kernel ('syrk' or 'herk').

    Stores the operand descriptions and the tiling choice, and derives the
    procedural name under which the kernel is emitted.
    """

    #
    def __init__(self, rank_k_kind, arch, tile_description, A, C, element_epilogue, \
        epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \
        blas_mode = BlasMode.symmetric):
        """Record the operation description.

        Args:
            rank_k_kind: key selecting the instance emitter and wrapper
                (looked up in the RankKKind-keyed tables of the emitters).
            arch: compute capability number, formatted as "Sm%d" on emission.
            tile_description: tile configuration; this module reads its
                math_instruction, threadblock_shape, warp_count, stages and
                procedural_name().
            A: operand A description (element, layout, alignment,
                complex_transform are read).
            C: operand C description (element, layout, alignment, fill_mode
                are read).
            element_epilogue: data type used by the epilogue functor.
            epilogue_functor: epilogue functor kind.
            swizzling_functor: threadblock swizzling functor kind.
            blas_mode: BlasMode.symmetric emits 'syrk'; any other value
                emits 'herk'.
        """
        self.blas_mode = blas_mode
        self.operation_kind = OperationKind.RankK
        self.arch = arch
        self.tile_description = tile_description
        self.rank_k_kind = rank_k_kind
        self.A = A
        self.C = C
        self.element_epilogue = element_epilogue
        self.epilogue_functor = epilogue_functor
        self.swizzling_functor = swizzling_functor

    #
    def is_complex(self):
        """Return True when the math operation is one of the complex multiply-add variants."""
        complex_operators = (
            MathOperation.multiply_add_complex,
            MathOperation.multiply_add_complex_gaussian,
            MathOperation.multiply_add_complex_fast_f32,
        )
        return self.tile_description.math_instruction.math_operation in complex_operators

    #
    def is_mixed_input(self):
        """A rank-k update has a single input operand; always False."""
        return False

    #
    def is_planar_complex(self):
        """Planar-complex kernels are not emitted by this class; always False."""
        return False

    #
    def accumulator_type(self):
        """Return the accumulator data type, promoted to its complex counterpart for complex math."""
        accum = self.tile_description.math_instruction.element_accumulator

        if self.is_complex():
            return get_complex_from_real(accum)

        return accum

    #
    def short_math_name(self):
        """Return the short accumulator type name, prefixed with 'g' for Gaussian complex math."""
        short_name = ShortDataTypeNames[self.accumulator_type()]
        if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
            return "g%s" % short_name
        return short_name

    #
    def core_name(self):
        ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
        math_instruction = self.tile_description.math_instruction

        inst_shape = ''
        intermediate_type = ''

        math_operations_map = {
            MathOperation.xor_popc: 'xor',
            MathOperation.and_popc: 'and'
        }

        # Only tensor-op kernels carry the instruction shape (and an optional
        # intermediate type) in their name.
        if math_instruction.opcode_class == OpcodeClass.TensorOp or \
           math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:

            math_op_string = math_operations_map.get(math_instruction.math_operation, '')

            inst_shape = "%d%d%d" % tuple(math_instruction.instruction_shape)
            inst_shape += math_op_string

            if math_instruction.element_a != self.A.element and \
               math_instruction.element_a != math_instruction.element_accumulator:
                intermediate_type = DataTypeNames[math_instruction.element_a]

        operation_name = 'syrk' if self.blas_mode == BlasMode.symmetric else 'herk'

        return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name)

    #
    def extended_name(self):
        ''' Append data types if they differ from compute type. '''
        accumulator = self.tile_description.math_instruction.element_accumulator

        if self.is_complex():
            extended_name = "${core_name}"
        elif self.C.element != accumulator and self.A.element != accumulator:
            extended_name = "${element_c}_${core_name}_${element_a}"
        elif self.C.element == accumulator and self.A.element != accumulator:
            extended_name = "${core_name}_${element_a}"
        else:
            extended_name = "${core_name}"

        extended_name = SubstituteTemplate(extended_name, {
            'element_a': DataTypeNames[self.A.element],
            'element_c': DataTypeNames[self.C.element],
            'core_name': self.core_name()
        })

        return extended_name

    #
    def layout_name(self):
        """Return the short layout name of A; complex kernels also encode A's complex transform."""
        if self.is_complex() or self.is_planar_complex():
            return "%s" % (
                ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)]
            )
        return "%s" % (ShortLayoutTypeNames[self.A.layout])

    #
    def fill_mode_name(self):
        """Return the short name of C's fill mode."""
        return "%s" % (ShortFillModeNames[self.C.fill_mode])

    #
    def procedural_name(self):
        ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
        threadblock = self.tile_description.procedural_name()

        opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]

        # The name carries A's alignment. The original also computed
        # max(A.alignment, C.alignment) but never used it; using it here would
        # rename existing kernels, so it is dropped instead.
        return SubstituteTemplate(
            "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${fill_mode}_align${alignment}",
            {
                'opcode_class': opcode_class_name,
                'extended_name': self.extended_name(),
                'threadblock': threadblock,
                'layout': self.layout_name(),
                'fill_mode': self.fill_mode_name(),
                'alignment': "%d" % self.A.alignment,
            }
        )

    #
    def configuration_name(self):
        ''' The configuration name is identical to the full procedural name. '''
        return self.procedural_name()

###################################################################################################
#
# Emits single instances of a CUTLASS device-wide operator
#
###################################################################################################

#
class EmitRankKUniversalInstance:
    ''' Responsible for emitting a CUTLASS template definition'''

    def __init__(self):
        # Template for real-valued kernels.
        self.rank_k_template = """
// Rank K operator ${operation_name}
using Operation_${operation_name} =
  typename cutlass::gemm::device::RankK<
    ${element_a}, ${layout_a},
    ${element_c}, ${layout_c}, ${fill_mode},
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor}<
      ${element_c},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_epilogue}
    >,
    ${swizzling_functor},
    ${stages},
    ${align_a},
    ${split_k_serial},
    ${math_operation}
>;
"""
        # Template for complex kernels: adds the complex transform and the BLAS mode.
        self.rank_k_complex_template = """
// Rank K operator ${operation_name}
using Operation_${operation_name} =
  typename cutlass::gemm::device::RankK<
    ${element_a}, ${layout_a},
    ${element_c}, ${layout_c}, ${fill_mode},
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor}<
      ${element_c},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_epilogue}
    >,
    ${swizzling_functor},
    ${stages},
    ${align_a},
    ${split_k_serial},
    ${math_operation},
    ${transform_a},
    ${blas_mode}
>;
"""

    def emit(self, operation):
        """Return the C++ `using` declaration for `operation` as a string."""
        tile = operation.tile_description
        math_instruction = tile.math_instruction
        threadblock_shape = tile.threadblock_shape

        warp_count = tile.warp_count
        warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]

        # The epilogue accesses at most 128 bits of C at a time.
        element_c_bits = DataTypeSize[operation.C.element]
        epilogue_vector_length = int(min(operation.C.alignment * element_c_bits, 128) / element_c_bits)

        values = {
            'operation_name': operation.procedural_name(),
            'element_a': DataTypeTag[operation.A.element],
            'layout_a': LayoutTag[operation.A.layout],
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[operation.C.layout],
            'fill_mode': FillModeTag[operation.C.fill_mode],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[math_instruction.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(threadblock_shape[0]),
            'threadblock_shape_n': str(threadblock_shape[1]),
            'threadblock_shape_k': str(threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(math_instruction.instruction_shape[0]),
            'instruction_shape_n': str(math_instruction.instruction_shape[1]),
            'instruction_shape_k': str(math_instruction.instruction_shape[2]),
            'epilogue_vector_length': str(epilogue_vector_length),
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
            'stages': str(tile.stages),
            'align_a': str(operation.A.alignment),
            'split_k_serial': 'false',
            'math_operation': MathOperationTag[math_instruction.math_operation],
            'transform_a': ComplexTransformTag[operation.A.complex_transform],
            'blas_mode': BlasModeTag[operation.blas_mode]
        }

        rank_k_template = self.rank_k_complex_template if operation.is_complex() else self.rank_k_template

        return SubstituteTemplate(rank_k_template, values)

###################################################################################################


###################################################################################################
#
# Emitters functions for all targets
#
###################################################################################################

class EmitRankKConfigurationLibrary:
    """Context manager that writes one generated `<configuration_name>.cu` file.

    `__enter__` opens the file and writes the header, `emit` collects one
    instance definition and one manifest registration per operation, and
    `__exit__` writes everything collected and closes the file.
    """

    def __init__(self, operation_path, configuration_name):
        self.configuration_name = configuration_name
        self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')

        self.instance_emitter = {
            RankKKind.Universal: EmitRankKUniversalInstance,
        }

        self.rank_k_kind_wrappers = {
            RankKKind.Universal: 'RankKOperation',
        }

        self.instance_template = {
            RankKKind.Universal: """
${compile_guard_start}
  manifest.append(new ${rank_k_kind}<
    Operation_${operation_name}
  >("${operation_name}"));
${compile_guard_end}
"""
        }

        # Opening preprocessor guard for WMMA kernels; emit() closes it with
        # "#endif". emit() read this attribute but it was never defined, so
        # every WmmaTensorOp operation raised AttributeError.
        self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"

        self.header_template = """
/*
  Generated by rank_k_operation.py - Do not edit.
*/

///////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "library_internal.h"
#include "rank_k_operation.h"

///////////////////////////////////////////////////////////////////////////////////////////////////

"""

        self.initialize_function_template = """

///////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace library {

///////////////////////////////////////////////////////////////////////////////////////////////////

void initialize_${configuration_name}(Manifest &manifest) {

"""
        self.epilogue_template = """

}

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////

"""

    def __enter__(self):
        """Open the output file, write the header and reset the collected state."""
        self.configuration_file = open(self.configuration_path, "w")
        self.configuration_file.write(self.header_template)

        self.instance_definitions = []
        self.instance_wrappers = []

        self.operations = []
        return self

    def emit(self, operation):
        """Collect the instance definition and manifest registration for `operation`."""
        emitter = self.instance_emitter[operation.rank_k_kind]()

        self.operations.append(operation)

        self.instance_definitions.append(emitter.emit(operation))

        # Only WMMA kernels are wrapped in a compile guard.
        is_wmma = operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp
        if is_wmma:
            compile_guard_start = SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)})
            compile_guard_end = "#endif"
        else:
            compile_guard_start = ""
            compile_guard_end = ""

        self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.rank_k_kind], {
            'configuration_name': self.configuration_name,
            'operation_name': operation.procedural_name(),
            'rank_k_kind': self.rank_k_kind_wrappers[operation.rank_k_kind],
            'compile_guard_start': compile_guard_start,
            'compile_guard_end': compile_guard_end
        }))

    def __exit__(self, exception_type, exception_value, traceback):
        """Write the collected definitions and registrations, then close the file.

        Exceptions raised inside the `with` body are not suppressed. The file
        is closed even if one of the writes fails.
        """
        try:
            # Write instance definitions in top-level namespace
            for instance_definition in self.instance_definitions:
                self.configuration_file.write(instance_definition)

            # Add wrapper objects within initialize() function
            self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
                'configuration_name': self.configuration_name
            }))

            for instance_wrapper in self.instance_wrappers:
                self.configuration_file.write(instance_wrapper)

            self.configuration_file.write(self.epilogue_template)
        finally:
            self.configuration_file.close()

###################################################################################################

# NOTE(review): the rest of this source line is not part of rank_k_operation.py. It is the
# unified-diff header and the first license lines of the next file, kept here one original
# diff line per comment so that nothing is dropped.
# diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_shapes.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_shapes.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..32e4376513679f06dc085ead068e258b3d8b5e72
# --- /dev/null
# +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_shapes.py
# @@ -0,0 +1,342 @@
# +#################################################################################################
# +#
# +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# +# SPDX-License-Identifier: BSD-3-Clause
# +#
# +# Redistribution and use in source and binary forms, with or without
# +# modification, are permitted provided that the following conditions are met:
# +#
# +# 1. Redistributions of source code must retain the above copyright notice, this
# +# list of conditions and the following disclaimer.
# +#
# +# 2.
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Valid tcgen05 shapes and cluster sizes for SM100, associated with levels. +These shape and level pairs are defined as dicts, where keys are shapes and values are their +associated levels. If the user input level for that category (tcgen05 shape, cluster +size) is smaller than a shape's associated level, the shape is excluded; otherwise it is included. +Shapes with higher levels are therefore emitted less often, while shapes with lower levels are emitted more frequently. +Level 0 is always emitted. 
+""" + +try: + from .library import DynamicClusterShape +except: + from library import DynamicClusterShape + +SM100_CLUSTER_SHAPES_1SM = { + tuple(DynamicClusterShape) : 0, + # size 1 cluster + (1, 1, 1): 1, + # size 2 cluster + (1, 2, 1): 2, + (2, 1, 1): 5, + # size 4 clusters + (2, 2, 1): 6, + (1, 4, 1): 3, + (4, 1, 1): 6, + # size 8 clusters + (2, 4, 1): 7, + (4, 2, 1): 7, + (1, 8, 1): 8, + (8, 1, 1): 8, + # size 16 cluster + (4, 4, 1): 4, +} + +SM100_CLUSTER_SHAPES_2SM = { + tuple(DynamicClusterShape) : 0, + # size 2 cluster + (2, 1, 1): 1, + # size 4 clusters + (2, 2, 1): 2, + (4, 1, 1): 2, + # size 8 clusters + (2, 4, 1): 3, + (4, 2, 1): 3, + (8, 1, 1): 6, + # size 16 cluster + (4, 4, 1): 4, +} + +# MMA shapes + +# 16b Dense + +SM100_MMA_SHAPES_16b_DENSE_1SM = { + (64, 8, 16): 5, + (64, 16, 16): 2, + (64, 24, 16): 5, + (64, 32, 16): 2, + (64, 40, 16): 5, + (64, 48, 16): 5, + (64, 56, 16): 5, + (64, 64, 16): 2, + (64, 72, 16): 5, + (64, 80, 16): 5, + (64, 88, 16): 5, + (64, 96, 16): 5, + (64, 104, 16): 5, + (64, 112, 16): 5, + (64, 120, 16): 5, + (64, 128, 16): 0, + (64, 136, 16): 5, + (64, 144, 16): 5, + (64, 152, 16): 5, + (64, 160, 16): 5, + (64, 168, 16): 5, + (64, 176, 16): 5, + (64, 184, 16): 5, + (64, 192, 16): 3, + (64, 200, 16): 5, + (64, 208, 16): 5, + (64, 216, 16): 5, + (64, 224, 16): 5, + (64, 232, 16): 5, + (64, 240, 16): 5, + (64, 248, 16): 5, + (64, 256, 16): 3, + + (128, 16, 16): 2, + (128, 32, 16): 2, + (128, 48, 16): 5, + (128, 64, 16): 2, + (128, 80, 16): 5, + (128, 96, 16): 5, + (128, 112, 16): 5, + (128, 128, 16): 0, + (128, 144, 16): 5, + (128, 160, 16): 5, + (128, 176, 16): 5, + (128, 192, 16): 3, + (128, 208, 16): 5, + (128, 224, 16): 5, + (128, 240, 16): 5, + (128, 256, 16): 0, + +} + + +SM100_MMA_SHAPES_16b_DENSE_2SM = { + (128, 32, 16): 2, + (128, 64, 16): 2, + (128, 96, 16): 5, + (128, 128, 16): 0, + (128, 160, 16): 5, + (128, 192, 16): 5, + (128, 224, 16): 5, + (128, 256, 16): 0, + + (256, 32, 16): 2, + (256, 64, 16): 2, + (256, 
96, 16): 5, + (256, 128, 16): 0, + (256, 160, 16): 5, + (256, 192, 16): 3, + (256, 224, 16): 5, + (256, 256, 16): 0, +} + +# TF32 Dense + +SM100_MMA_SHAPES_TF32_DENSE_1SM = { + (64, 8, 8): 5, + (64, 16, 8): 2, + (64, 24, 8): 5, + (64, 32, 8): 2, + (64, 40, 8): 5, + (64, 48, 8): 5, + (64, 56, 8): 5, + (64, 64, 8): 1, + (64, 72, 8): 5, + (64, 80, 8): 5, + (64, 88, 8): 5, + (64, 96, 8): 5, + (64, 104, 8): 5, + (64, 112, 8): 5, + (64, 120, 8): 5, + (64, 128, 8): 0, + (64, 136, 8): 5, + (64, 144, 8): 5, + (64, 152, 8): 5, + (64, 160, 8): 5, + (64, 168, 8): 5, + (64, 176, 8): 5, + (64, 184, 8): 5, + (64, 192, 8): 3, + (64, 200, 8): 5, + (64, 208, 8): 5, + (64, 216, 8): 5, + (64, 224, 8): 5, + (64, 232, 8): 5, + (64, 240, 8): 5, + (64, 248, 8): 5, + (64, 256, 8): 3, + + (128, 16, 8): 2, + (128, 32, 8): 2, + (128, 48, 8): 5, + (128, 64, 8): 2, + (128, 80, 8): 5, + (128, 96, 8): 5, + (128, 112, 8): 5, + (128, 128, 8): 0, + (128, 144, 8): 5, + (128, 160, 8): 5, + (128, 176, 8): 5, + (128, 192, 8): 3, + (128, 208, 8): 5, + (128, 224, 8): 5, + (128, 240, 8): 5, + (128, 256, 8): 0, + +} + +SM100_MMA_SHAPES_TF32_DENSE_2SM = { + (128, 32, 8): 2, + (128, 64, 8): 1, + (128, 96, 8): 5, + (128, 128, 8): 0, + (128, 160, 8): 5, + (128, 192, 8): 5, + (128, 224, 8): 5, + (128, 256, 8): 0, + + (256, 32, 8): 2, + (256, 64, 8): 1, + (256, 96, 8): 5, + (256, 128, 8): 0, + (256, 160, 8): 5, + (256, 192, 8): 5, + (256, 224, 8): 5, + (256, 256, 8): 0, +} + +# F8F6F4 +SM100_MMA_SHAPES_F8F6F4_DENSE_1SM = { + (64, 8, 32): 4, + (64, 16, 32): 4, + (64, 24, 32): 5, + (64, 32, 32): 3, + (64, 40, 32): 5, + (64, 48, 32): 5, + (64, 56, 32): 5, + (64, 64, 32): 2, + (64, 72, 32): 5, + (64, 80, 32): 5, + (64, 88, 32): 5, + (64, 96, 32): 5, + (64, 104, 32): 5, + (64, 112, 32): 5, + (64, 120, 32): 5, + (64, 128, 32): 0, + (64, 136, 32): 5, + (64, 144, 32): 5, + (64, 152, 32): 5, + (64, 160, 32): 5, + (64, 168, 32): 5, + (64, 176, 32): 5, + (64, 184, 32): 5, + (64, 192, 32): 5, + (64, 200, 32): 5, + (64, 208, 
32): 5, + (64, 216, 32): 5, + (64, 224, 32): 5, + (64, 232, 32): 5, + (64, 240, 32): 5, + (64, 248, 32): 5, + (64, 256, 32): 0, + + (128, 16, 32): 4, + (128, 32, 32): 3, + (128, 48, 32): 5, + (128, 64, 32): 2, + (128, 80, 32): 5, + (128, 96, 32): 5, + (128, 112, 32): 5, + (128, 128, 32): 0, + (128, 144, 32): 5, + (128, 160, 32): 5, + (128, 176, 32): 5, + (128, 192, 32): 5, + (128, 208, 32): 5, + (128, 224, 32): 5, + (128, 240, 32): 5, + (128, 256, 32): 0, + +} + +SM100_MMA_SHAPES_F8F6F4_DENSE_2SM = { + (128, 32, 32): 3, + (128, 64, 32): 2, + (128, 96, 32): 5, + (128, 128, 32): 1, + (128, 160, 32): 5, + (128, 192, 32): 5, + (128, 224, 32): 5, + (128, 256, 32): 1, + + (256, 32, 32): 2, + (256, 64, 32): 2, + (256, 96, 32): 5, + (256, 128, 32): 0, + (256, 160, 32): 5, + (256, 192, 32): 5, + (256, 224, 32): 5, + (256, 256, 32): 0, +} + +# MXF8F6F4 +SM100_MMA_SHAPES_MXF8F6F4_DENSE_1SM = { + (128, 64, 32): 1, + (128, 128, 32): 0, + (128, 192, 32): 1, + (128, 256, 32): 0, +} + + +SM100_MMA_SHAPES_MXF8F6F4_DENSE_2SM = { + (256, 64, 32): 1, + (256, 128, 32): 0, + (256, 192, 32): 1, + (256, 256, 32): 0, + + +} + +# MXF4NVF4 +SM100_MMA_SHAPES_MXF4NVF4_DENSE_1SM = { + (128, 64, 64): 1, + (128, 128, 64): 0, + (128, 192, 64): 1, + (128, 256, 64): 0, +} + +SM100_MMA_SHAPES_MXF4NVF4_DENSE_2SM = { + # Multiples of 16 for N + (256, 64, 64): 1, + (256, 128, 64): 0, + (256, 192, 64): 1, + (256, 256, 64): 0, + +} diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_utils.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9bf24fe7f528020be4dcfc6ac41cfe949dd63be5 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_utils.py @@ -0,0 +1,661 @@ +################################################################################################# +# +# 
Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Utilities for enumerating CUTLASS library SM100 kernels +""" + +import argparse +import enum +from itertools import product +import math +import logging +import os.path +import shutil +import sys +import copy +from typing import Any, Optional, Sequence, Tuple, List, Union, Callable + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + +#### Step 0: define levels + +# One integer level controls multiple "generators" and how many +# combinations they generate. That is the "global" level. +# "Generators" are WGMMA shapes, MMA multipliers, cluster sizes, and +# anything that is eventually involved in the Cartesian product +# which yields our kernel configurations. +# For simplicity, each generator defines their own levels, +# starting from 0. As a rule we assume 10 or fewer levels, making +# their level a digit. +# The "global" level simply stacks these digits and represents them +# as a single integer. +# +# For example, level 500 indicates cluster sizes are at level 5, MMA +# multipliers are at level 0, and WGMMA shapes are at level 0 as well. +# +# Here we define the global level to generator level mappings. 
+ + +def get_tcgen05_level_from_global_level(global_level: int): + return global_level % 10 + +def get_mma_level_from_global_level(global_level: int): + return (global_level // 10) % 10 + + +def get_cluster_level_from_global_level(global_level: int): + return (global_level // 100) % 10 + + +def get_pruning_level_from_global_level(global_level: int): + return (global_level // 1000) % 10 + + +#### Step 1: generate MMA instruction shapes based on levels + +try: + from .sm100_shapes import * +except: + from sm100_shapes import * + +########### + +def generate_tf32_math_instructions_sm100(level: int): + """ + Generate all TensorOp math instructions for TF32 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_TF32_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_TF32_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + for shape in shapes_2sm: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_16b_math_instructions_sm100(level: int): + """ + Generate all TensorOp math instructions for 16b MMA that are supported by SM100 at or above the given level. 
+ + Args: + level: The global level to generate math instructions for. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_16b_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_16b_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + + for shape in shapes_2sm: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + + +def generate_fp8_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all TensorOp math instructions for FP8 MMA that are supported by SM100 at or above the given level. 
+ + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + + tcgen05_level = get_tcgen05_level_from_global_level(level) + pruning_level = get_pruning_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if enable_compile_time_dtype: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if pruning_level >= 2: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + for shape in shapes_2sm: + if enable_runtime_dtype: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + 
) + if enable_compile_time_dtype: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if pruning_level >= 2: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_f8f6f4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all TensorOp math instructions for FP8 FP6 and FP4 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. 
+ """ + + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, DataType.e5m2, DataType.e3m2, DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + + for shape in shapes_2sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, DataType.e5m2, DataType.e3m2, DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_mxf8f6f4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all BlockScaledTensorOp math instructions for MXFP8, MXFP6, and MXFP4 MMA that are supported 
by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + + tcgen05_level = get_tcgen05_level_from_global_level(level) + pruning_level = get_pruning_level_from_global_level(level) + + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + + if pruning_level < 2 and ((a_type == DataType.f8 or b_type == DataType.f8)): + continue + + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, + DataType.e5m2, + DataType.e3m2, + DataType.e2m3, + DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + + for shape in shapes_2sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + + if pruning_level < 2 and ((a_type == DataType.f8 or b_type == DataType.f8)): 
+ continue + + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, + DataType.e5m2, + DataType.e3m2, + DataType.e2m3, + DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_mxf4nvf4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all BlockScaledTensorOp math instructions for MXFP4 and NVFP4 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. 
+ """ + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF4NVF4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF4NVF4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e2m1, + ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + for shape in shapes_2sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e2m1, + ] + + for 
a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + return math_instructions_1sm, math_instructions_2sm + + +def generate_cluster_shapes_sm100(level: int, change_priority_func : Union[Callable, None] = None): + """ + Generate all cluster shapes for SM100 at or above the given level. + + Args: + level: The global level to generate cluster shapes for. + + Returns: + A tuple of two lists of cluster shapes. + The first list contains the cluster shapes for 1SM, and the second list contains the cluster shapes for 2SM. + """ + cluster_level = get_cluster_level_from_global_level(level) + + assert cluster_level >= 4 + + if change_priority_func is not None: + SM100_CLUSTER_SHAPES_1SM_CPY = copy.deepcopy(SM100_CLUSTER_SHAPES_1SM) + SM100_CLUSTER_SHAPES_2SM_CPY = copy.deepcopy(SM100_CLUSTER_SHAPES_2SM) + change_priority_func(SM100_CLUSTER_SHAPES_1SM_CPY, SM100_CLUSTER_SHAPES_2SM_CPY) + shapes_1sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_1SM_CPY.items() if cluster_level >= min_level + ] + shapes_2sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_2SM_CPY.items() if cluster_level >= min_level + ] + + return shapes_1sm, shapes_2sm + + else: + + shapes_1sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_1SM.items() if cluster_level >= min_level + ] + shapes_2sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_2SM.items() if cluster_level >= min_level + ] + + return shapes_1sm, shapes_2sm diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_shapes.py 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..e14761aae6494f877e6dc6521b30baea0db7509c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_shapes.py @@ -0,0 +1,212 @@ +################################################################################################# +# +# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Valid WGMMA shapes, MMA multipliers, and cluster sizes for SM90, associated with levels. +These shape and level pairs are defined as dicts, where keys are shapes and values are their +associated levels. If the user input level for that category (MMA multiplier, WGMMA shape, cluster +size) is smaller than a shape's associated level, the shape is excluded; otherwise it is included. +Higher levels are therefore emitted less often, while lower levels are emitted more frequently. +Level 0 is always emitted. The default behavior in `generator.py` is that level 1 is only emitted +when the `--kernel` argument is non-empty. +""" + +# NOTE: more combinations are possible here. +# Levels [0, 3] exist in order to control exactly what configs are generated in different dtypes. +# The rest are only used in the exhaustive mode (when the corresponding level digit is 9). +# MMA multipliers are multiplied by MMA instruction shapes (WGMMA shapes) to get CTA shapes. 
+SM90_MMA_MULTIPLIERS = { + (2, 1, 4): 0, + (1, 1, 4): 1, + (4, 1, 4): 2, + (2, 2, 4): 3, + (2, 1, 8): 4, + (4, 1, 8): 4, + (1, 1, 8): 4, + (2, 2, 8): 4, + (2, 1, 16): 5, + (4, 1, 16): 5, + (1, 1, 16): 5, + (2, 2, 16): 5, +} + +# Level 0: only (1, 2, 1) -- fp8 dense gemms in pruned case +# Level 1: clusters with 2 CTAs -- all but fp8 (s8, u8, f16, b16, f32, tf32) dense gemms in pruned case +# Level 2: clusters with 1 or 2 CTAs +# Level 3: clusters with 1, 2, or 4 CTAs +# Level 4: clusters with 1, 2, 4, or 8 CTAs +# Level 5: clusters with 1, 2, 4, 8, or 16 CTAs +SM90_CLUSTER_SIZES = { + (1, 2, 1): 0, + (2, 1, 1): 1, + (1, 1, 1): 2, + (2, 2, 1): 3, + (1, 4, 1): 3, + (4, 1, 1): 3, + (2, 4, 1): 4, + (4, 2, 1): 4, + (1, 8, 1): 4, + (8, 1, 1): 4, + (4, 4, 1): 5, +} + + +# WGMMA shapes +# Level 0: "default" shape only, +# Level 1: additional shapes for the unpruned case (tf32 only) +# Level 2: shapes that are all powers of 2 +# Level 3: all other shapes +SM90_WGMMA_SHAPES_FP16_BF16_DENSE = { + (64, 8, 16): 2, + (64, 16, 16): 2, + (64, 24, 16): 3, + (64, 32, 16): 2, + (64, 40, 16): 3, + (64, 48, 16): 3, + (64, 56, 16): 3, + (64, 64, 16): 2, + (64, 72, 16): 3, + (64, 80, 16): 3, + (64, 88, 16): 3, + (64, 96, 16): 3, + (64, 104, 16): 3, + (64, 112, 16): 3, + (64, 120, 16): 3, + (64, 128, 16): 0, + (64, 136, 16): 3, + (64, 144, 16): 3, + (64, 152, 16): 3, + (64, 160, 16): 3, + (64, 168, 16): 3, + (64, 176, 16): 3, + (64, 184, 16): 3, + (64, 192, 16): 3, + (64, 200, 16): 3, + (64, 208, 16): 3, + (64, 216, 16): 3, + (64, 224, 16): 3, + (64, 232, 16): 3, + (64, 240, 16): 3, + (64, 248, 16): 3, + (64, 256, 16): 1, +} + +SM90_WGMMA_SHAPES_TF32_DENSE = { + (64, 8, 8): 2, + (64, 16, 8): 2, + (64, 24, 8): 3, + (64, 32, 8): 2, + (64, 40, 8): 3, + (64, 48, 8): 3, + (64, 56, 8): 3, + (64, 64, 8): 2, + (64, 72, 8): 3, + (64, 80, 8): 3, + (64, 88, 8): 3, + (64, 96, 8): 3, + (64, 104, 8): 3, + (64, 112, 8): 3, + (64, 120, 8): 3, + (64, 128, 8): 0, + (64, 136, 8): 3, + (64, 144, 8): 3, + 
(64, 152, 8): 3, + (64, 160, 8): 3, + (64, 168, 8): 3, + (64, 176, 8): 3, + (64, 184, 8): 3, + (64, 192, 8): 3, + (64, 200, 8): 3, + (64, 208, 8): 3, + (64, 216, 8): 3, + (64, 224, 8): 3, + (64, 232, 8): 3, + (64, 240, 8): 3, + (64, 248, 8): 3, + (64, 256, 8): 1, +} + +SM90_WGMMA_SHAPES_FP8_DENSE = { + (64, 8, 32): 2, + (64, 16, 32): 2, + (64, 24, 32): 3, + (64, 32, 32): 2, + (64, 40, 32): 3, + (64, 48, 32): 3, + (64, 56, 32): 3, + (64, 64, 32): 2, + (64, 72, 32): 3, + (64, 80, 32): 3, + (64, 88, 32): 3, + (64, 96, 32): 3, + (64, 104, 32): 3, + (64, 112, 32): 3, + (64, 120, 32): 3, + (64, 128, 32): 0, + (64, 136, 32): 3, + (64, 144, 32): 3, + (64, 152, 32): 3, + (64, 160, 32): 3, + (64, 168, 32): 3, + (64, 176, 32): 3, + (64, 184, 32): 3, + (64, 192, 32): 3, + (64, 200, 32): 3, + (64, 208, 32): 3, + (64, 216, 32): 3, + (64, 224, 32): 3, + (64, 232, 32): 3, + (64, 240, 32): 3, + (64, 248, 32): 3, + (64, 256, 32): 1, +} + +SM90_WGMMA_SHAPES_INT8_DENSE = { + (64, 8, 32): 2, + (64, 16, 32): 2, + (64, 24, 32): 3, + (64, 32, 32): 2, + (64, 48, 32): 3, + (64, 64, 32): 2, + (64, 80, 32): 3, + (64, 96, 32): 3, + (64, 112, 32): 3, + (64, 128, 32): 0, + (64, 144, 32): 3, + (64, 160, 32): 3, + (64, 176, 32): 3, + (64, 192, 32): 3, + (64, 208, 32): 3, + (64, 224, 32): 3, + (64, 240, 32): 3, + (64, 256, 32): 1, +} diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_utils.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fc5fdf14abb85835f71ecfd704a2738f5792af50 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_utils.py @@ -0,0 +1,753 @@ +################################################################################################# +# +# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Utilities for enumerating CUTLASS library SM90 kernels +""" + +import argparse +import enum +from itertools import product +import math +import logging +import os.path +import shutil +import sys +import copy +from typing import Any, Optional, Sequence, Tuple, List + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + +# NOTE: this is a duplicate of CudaToolkitVersionSatisfies in generator.py +def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0): + + # by default, use the latest CUDA Toolkit version + cuda_version = [11, 0, 132] + + # Update cuda_version based on parsed string + if semantic_ver_string != '': + for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')[:3]]): + if i < len(cuda_version): + cuda_version[i] = x + else: + cuda_version.append(x) + return cuda_version >= [major, minor, patch] + +#### Step 0: define levels + +# One integer level controls multiple "generators" and how many +# combinations they generate. That is the "global" level. +# "Generators" are WGMMA shapes, MMA multipliers, cluster sizes, and +# anything that is eventually involved in the Cartesian product +# which yields our kernel configurations. +# For simplicity, each generator defines their own levels, +# starting from 0. As a rule we assume 10 or fewer levels, making +# their level a digit. +# The "global" level simply stacks these digits and represents them +# as a single integer. +# +# For example, level 500 indicates cluster sizes are at level 5, MMA +# multipliers are at level 0, and WGMMA shapes are at level 0 as well. +# +# Here we define the global level to generator level mappings. 
+ + +def get_wgmma_level_from_global_level(global_level: int): + return global_level % 10 + + +def get_mma_level_from_global_level(global_level: int): + return (global_level // 10) % 10 + + +def get_cluster_level_from_global_level(global_level: int): + return (global_level // 100) % 10 + + +def get_pruning_level_from_global_level(global_level: int): + return (global_level // 1000) % 10 + + +#### Step 1: generate MMA instruction shapes based on levels + +try: + from .sm90_shapes import ( + SM90_MMA_MULTIPLIERS, + SM90_CLUSTER_SIZES, + SM90_WGMMA_SHAPES_TF32_DENSE, + SM90_WGMMA_SHAPES_FP16_BF16_DENSE, + SM90_WGMMA_SHAPES_FP8_DENSE, + SM90_WGMMA_SHAPES_INT8_DENSE, + ) +except: + from sm90_shapes import ( + SM90_MMA_MULTIPLIERS, + SM90_CLUSTER_SIZES, + SM90_WGMMA_SHAPES_TF32_DENSE, + SM90_WGMMA_SHAPES_FP16_BF16_DENSE, + SM90_WGMMA_SHAPES_FP8_DENSE, + SM90_WGMMA_SHAPES_INT8_DENSE, + ) + + +def generate_tf32_math_instruction_shapes_sm90(level: int): + assert isinstance(level, int) and level >= 0 + filtered_list_of_wgmma_shapes = [ + wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_TF32_DENSE.items() if level >= min_level + ] + return filtered_list_of_wgmma_shapes + +def generate_fp16_bf16_math_instruction_shapes_sm90(level: int): + assert isinstance(level, int) and level >= 0 + filtered_list_of_wgmma_shapes = [ + wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_FP16_BF16_DENSE.items() if level >= min_level + ] + return filtered_list_of_wgmma_shapes + +def generate_fp8_math_instruction_shapes_sm90(level: int): + assert isinstance(level, int) and level >= 0 + filtered_list_of_wgmma_shapes = [ + wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_FP8_DENSE.items() if level >= min_level + ] + return filtered_list_of_wgmma_shapes + +def generate_int8_math_instruction_shapes_sm90(level: int): + assert isinstance(level, int) and level >= 0 + filtered_list_of_wgmma_shapes = [ + wgmma_shape for wgmma_shape, min_level in 
SM90_WGMMA_SHAPES_INT8_DENSE.items() if level >= min_level + ] + return filtered_list_of_wgmma_shapes + +def generate_mixed_dtype_math_instructions_shapes_sm90(wgmma_level: int, a_type: DataType, b_type: DataType): + # DataTypeSize are in the unit of bits + a_bytes = DataTypeSize[a_type] // 8 + b_bytes = DataTypeSize[b_type] // 8 + if a_bytes == 4 or b_bytes == 4: + return generate_tf32_math_instruction_shapes_sm90(wgmma_level) + elif a_bytes == 2 or b_bytes == 2: + return generate_fp16_bf16_math_instruction_shapes_sm90(wgmma_level) + else: + return generate_fp8_math_instruction_shapes_sm90(wgmma_level) + +########### + +def generate_tf32_math_instructions_sm90(level: int): + wgmma_level = get_wgmma_level_from_global_level(level) + math_instructions = [] + for math_instruction_shape in generate_tf32_math_instruction_shapes_sm90(wgmma_level): + math_instructions.append( + MathInstruction( + math_instruction_shape, + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + return math_instructions + +def generate_fp16_bf16_math_instructions_sm90(level: int): + wgmma_level = get_wgmma_level_from_global_level(level) + math_instructions = [] + for math_instruction_shape in generate_fp16_bf16_math_instruction_shapes_sm90(wgmma_level): + math_instructions += [ + MathInstruction( + math_instruction_shape, + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + math_instruction_shape, + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + math_instruction_shape, + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + return math_instructions + +def generate_fp8_math_instructions_sm90(level: int): + wgmma_level = get_wgmma_level_from_global_level(level) + math_instructions = [] + for math_instruction_shape in 
generate_fp8_math_instruction_shapes_sm90(wgmma_level): + math_instructions += [ + MathInstruction( + math_instruction_shape, + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + math_instruction_shape, + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + math_instruction_shape, + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + math_instruction_shape, + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + return math_instructions + +def generate_mixed_dtype_math_instructions_sm90(level: int, types_of_a_b_acc: List[Tuple[DataType, DataType, DataType]]): + wgmma_level = get_wgmma_level_from_global_level(level) + math_instructions = [] + for a_type, b_type, acc_type in types_of_a_b_acc: + math_instruction_shapes = generate_mixed_dtype_math_instructions_shapes_sm90(wgmma_level, a_type, b_type) + for math_instruction_shape in math_instruction_shapes: + math_instructions += [ + MathInstruction( + math_instruction_shape, + a_type, b_type, acc_type, + OpcodeClass.TensorOp, + MathOperation.multiply_add + ), + ] + return math_instructions + +def generate_int8_math_instructions_sm90(level: int): + wgmma_level = get_wgmma_level_from_global_level(level) + math_instructions = [] + for math_instruction_shape in generate_int8_math_instruction_shapes_sm90(wgmma_level): + math_instructions += [ + MathInstruction( + math_instruction_shape, + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + math_instruction_shape, + DataType.u8, DataType.u8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + return math_instructions + +def make_sparse_math_instructions(math_instructions): + sparse_instructions = [] + for inst in math_instructions: + if 
inst.opcode_class == OpcodeClass.TensorOp: + sparse_instructions.append(MathInstruction( + (inst.instruction_shape[0], inst.instruction_shape[1], inst.instruction_shape[2] * 2), + inst.element_a, inst.element_b, inst.element_accumulator, + OpcodeClass.SparseTensorOp, + inst.math_operation),) + return sparse_instructions + + +#### Step 2: generate tile descriptions from math instruction shapes + +def is_tile_desc_valid(tile_description): + if tile_description.minimum_compute_capability != 90 or tile_description.maximum_compute_capability != 90: + return False + + element_a, element_b, element_accum = ( + tile_description.math_instruction.element_a, + tile_description.math_instruction.element_b, + tile_description.math_instruction.element_accumulator + ) + + cluster_size, cta_shape = ( + tile_description.cluster_shape, + tile_description.threadblock_shape, + ) + grid_size = ( + cta_shape[0] * cluster_size[0] + + cta_shape[1] * cluster_size[1] + + cta_shape[2] * cluster_size[2] + ) + num_ctas_in_cluster = cluster_size[0] * cluster_size[1] * cluster_size[2] + cluster_shape = ( + cluster_size[0] * cta_shape[0], + cluster_size[1] * cta_shape[1], + cluster_size[2] * cta_shape[2] + ) + + FP32_TYPES = [DataType.f32, DataType.tf32] + FP16_TYPES = [DataType.f16, DataType.bf16] + is_fp32 = element_a in FP32_TYPES and element_b in FP32_TYPES + is_fp16 = element_a in FP16_TYPES and element_b in FP16_TYPES + + # Maximum number of CTAs per cluster is 8 for Hopper, but up to 16 is + # allowed for non portable clusters. + if num_ctas_in_cluster > 16 or num_ctas_in_cluster < 1: + return False + + if grid_size < 1: + return False + + # SM90 WGMMA shapes are always 64 across M, therefore + # CTA shape across M must always be a multiple of 64. + if cta_shape[0] < 64 or cta_shape[0] % 64 != 0: + return False + + # The minimum WGMMA shape across N is 8, and increments + # vary across different dtypes, but they're never smaller + # than 8. 
The minimum CTA shape allowed across N though is 16. + if cta_shape[1] < 16 or cta_shape[1] % 8 != 0: + return False + + # SM90 WGMMA shapes across K are always 8 for 32 bit dense + # operations, 16 for 16 bit, and 32 for 8 bit. In any case, + # the CTA shape across K should be a multiple of 8 and at least + # twice the WGMMA shape across K. + if cta_shape[2] < 16 or cta_shape[2] % 8 != 0: + return False + + # Minimum of 2 stages (very rough heuristic that may filter out valid kernel configs) + if (cluster_shape[0] >= 128 or cluster_shape[1] >= 128) and cluster_shape[2] >= 256: + return False + + if is_fp32 and (cluster_shape[0] >= 128 or cluster_shape[1] >= 128) and cluster_shape[2] >= 128: + return False + + if is_fp32 and cluster_shape[0] >= 256 and cluster_shape[1] >= 256 and cluster_shape[2] >= 64: + return False + + if is_fp16 and cluster_shape[0] >= 256 and cluster_shape[1] >= 256 and cluster_shape[2] >= 128: + return False + + # CTA shape upper bound: <256, 256, 256> + if cta_shape[0] > 256 or cta_shape[1] > 256 or cta_shape[2] > 256: + return False + + return True + +def get_mma_multipliers(level: int): + assert isinstance(level, int) and level >= 0 + mma_level = get_mma_level_from_global_level(level) + return [ + mma_mul for mma_mul, mma_min_level in SM90_MMA_MULTIPLIERS.items() if mma_level >= mma_min_level + ] + +def get_cluster_sizes(level: int, is_aligned: bool): + if not is_aligned: + return [(1, 1, 1)] + assert isinstance(level, int) and level >= 0 + cluster_level = get_cluster_level_from_global_level(level) + return [ + cluster_size for cluster_size, cluster_min_level in SM90_CLUSTER_SIZES.items() if cluster_level >= cluster_min_level + ] + +def generate_tile_descriptions_sm90(math_instructions, is_aligned: bool, level: int): + tile_descriptions = set() + mma_multipliers, cluster_sizes = get_mma_multipliers(level), get_cluster_sizes(level, is_aligned) + for math_inst, mma_mul, cluster_size in product(math_instructions, mma_multipliers, 
cluster_sizes): + + # generator can stamp out duplicate kernels, because it doesn't explicitly set instruction + # shape for SM90 kernels, and the 3.X collective API doesn't directly expose them when using + # the auto kernel schedule. + + math_inst_stub = copy.deepcopy(math_inst) + math_inst_stub.instruction_shape = [0, 0, 0] + + tile_desc = TileDescription( + threadblock_shape=[ + math_inst.instruction_shape[0] * mma_mul[0], + math_inst.instruction_shape[1] * mma_mul[1], + math_inst.instruction_shape[2] * mma_mul[2] + ], + stages=0, + warp_count=[4, 1, 1], + math_instruction=math_inst_stub, + min_compute=90, + max_compute=90, + cluster_shape=cluster_size) + # For sparse kernels K-tile is twice as large (due to 2x MMA-K size) + # Reduce it to same size as dense to afford more smem stages + if math_inst.opcode_class == OpcodeClass.SparseTensorOp: + tile_desc.threadblock_shape[2] = tile_desc.threadblock_shape[2] // 2 + if is_tile_desc_valid(tile_desc): + tile_descriptions.add(tile_desc) + + return tile_descriptions + +#### Step 3: map tile description to valid schedules + +def is_tile_desc_compatible_with_cooperative(tile_description): + # Cooperative kernels require a minimum CTA-M of 128 + return tile_description.threadblock_shape[0] % 128 == 0 + + +def can_tile_desc_use_shmem_in_epilogue(tile_description, data_types): + dtype_a, dtype_b, dtype_c, dtype_d, dtype_acc, dtype_epi = ( + data_types["a_type"], + data_types["b_type"], + data_types["c_type"], + data_types["d_type"], + data_types["acc_type"], + data_types["epi_type"] + ) + mn = tile_description.threadblock_shape[0] * tile_description.threadblock_shape[1] + bitsize_c, bitsize_d = DataTypeSize[dtype_c], DataTypeSize[dtype_d] + + shmem_bits_c, shmem_bits_d = bitsize_c * mn, bitsize_d * mn + shmem_bits_total = shmem_bits_c + shmem_bits_d + # Magic number: 2^20 + # Existing logic suggested that tile shape 256x128 (or 128x256) + # would run out of shmem if D is FP32, and source is needed. 
+ # That would be 256 * 128 * 32 == 2^21 (~262 KB), which is over the limit. + # Hopper's max shmem size is 228 KB, and 2^20 ~= 131 KB. + # Since epilogue can't possibly use ALL of the shmem available + # we can just settle on 2^20 bits (~ 131 KB) being the upper bound + # we would allow for epilogue. + # This can be different for non-persistent kernels where epilogue and + # mainloop shmem is shared. + if shmem_bits_total > 2 ** 20: + return False + + return True + + +def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, layout, + instantiation_level, enable_fp8_fast_acc=True, gemm_kind=GemmKind.Universal3x): + # Level 0: prune according to existing generator.py behavior + # Level >= 1: no pruning + level = get_pruning_level_from_global_level(instantiation_level) + schedules = [] + stream_k_schedules = [] + + if not is_tile_desc_valid(tile_description): + return schedules, stream_k_schedules + + FP16_TYPES = [DataType.f16, DataType.bf16] + is_fp16 = data_types["a_type"] in FP16_TYPES and data_types["b_type"] in FP16_TYPES + + FP8_TYPES = [DataType.e4m3, DataType.e5m2] + is_fp8 = data_types["a_type"] in FP8_TYPES and data_types["b_type"] in FP8_TYPES + can_do_fp8_fast_accum = is_fp8 and enable_fp8_fast_acc + + FP32_TYPES = [DataType.f32, DataType.tf32] + is_fp32 = data_types["a_type"] in FP32_TYPES and data_types["b_type"] in FP32_TYPES + requires_transposed_epilogue = is_fp32 and layout[0][0] == LayoutType.RowMajor and layout[1][0] == LayoutType.RowMajor + + can_do_cooperative = is_tile_desc_compatible_with_cooperative(tile_description) + can_do_tma_epilogue = is_aligned and not requires_transposed_epilogue and can_tile_desc_use_shmem_in_epilogue(tile_description, data_types) + + default_epilogue = EpilogueScheduleType.NoSmemWarpSpecialized if not requires_transposed_epilogue else EpilogueScheduleType.EpilogueTransposed + auto_epilogue = EpilogueScheduleType.ScheduleAuto if not requires_transposed_epilogue else 
EpilogueScheduleType.EpilogueTransposed + + cta_m, cta_n, cta_k = ( + tile_description.threadblock_shape[0], + tile_description.threadblock_shape[1], + tile_description.threadblock_shape[2] + ) + c_type = data_types["c_type"] + d_type = data_types["d_type"] + is_void_c = c_type == DataType.void + + # Filter out invalid kernels + is_nt = layout[0][0] == LayoutType.ColumnMajor and layout[1][0] == LayoutType.RowMajor + is_tn = layout[0][0] == LayoutType.RowMajor and layout[1][0] == LayoutType.ColumnMajor + is_nn = layout[0][0] == LayoutType.ColumnMajor and layout[1][0] == LayoutType.ColumnMajor + + # static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0, + # "Copy size must evenly divide SMEM tile."); + if is_fp32 and is_nt and (cta_n % cta_k != 0): + return [], [] + + # static_assert(!TransposeB || (cutlass::bits_to_bytes((size<1>(SmemLayoutB{}) * sizeof_bits::value))) == 128, + # "SmemLayoutB K must be 128bytes to be transposed.") + if is_fp32 and is_nt and cta_k != 32: + return [], [] + + # Static assert failure when instantiating SmemLayoutB + if is_fp32 and (is_tn or is_nn) and (cta_n % cta_k != 0): + return [], [] + + grouped = is_grouped(gemm_kind) + if grouped: + # the following cases are unsupported by grouped GEMM + if not is_aligned: + return [], [] + if requires_transposed_epilogue: + return [], [] + + # Early pruning + if level < 1: + # Don't stamp out FP16/BF16 kernels smaller than or equal to 64x128x64 + if is_fp16 and cta_m <= 64 and cta_n <= 128 and cta_k <= 64: + return [], [] + + # FP8 configs with CTA tile larger than or equal to 256x128x128 limit data types and schedules + is_large_fp8_tile = is_fp8 and cta_m >= 256 and cta_n >= 128 and cta_k >= 128 + if is_large_fp8_tile: + # Only void-C, and only FP8 outputs allowed + if not is_void_c or d_type not in FP8_TYPES: + return [], [] + if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue: + schedules = [] + if is_blockwise(gemm_kind): + 
schedules.append( + [ + to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + else: + schedules.append( + [ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + schedules.append( + [ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + return schedules, [] + return [], [] + + if is_fp8 and not is_large_fp8_tile: + valid_dtypes_for_c = [DataType.f32, DataType.bf16, DataType.f16, DataType.void] + # Prune all configs with fp8 source, and all configs with non-fp8 output + # that have different dtypes for source and output. + if c_type not in valid_dtypes_for_c or (d_type not in FP8_TYPES and c_type != d_type): + return [], [] + + # FP32/TF32 kernels don't stamp out void-C + if is_fp32 and is_void_c: + return [], [] + + # Void-c only makes a difference for TMA epilogues + if is_void_c and not can_do_tma_epilogue: + return [], [] + + # For mixed input data types + a_type_size = DataTypeSize[data_types["a_type"]] + b_type_size = DataTypeSize[data_types["b_type"]] + if a_type_size != b_type_size and CudaToolkitVersionSatisfies(cuda_version, 12, 1): + schedules = [] + stream_k_schedules = [] + epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized + if a_type_size > b_type_size: + epilogue_schedule = EpilogueScheduleType.EpilogueTransposed + + if not is_blockwise(gemm_kind): + schedules.append([ + KernelScheduleType.TmaWarpSpecialized, + epilogue_schedule + ]) + schedules.append([ + KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule + ]) + if cta_m >= 128: + if a_type_size > b_type_size: + epilogue_schedule = EpilogueScheduleType.EpilogueTransposed + else: + epilogue_schedule = 
EpilogueScheduleType.TmaWarpSpecializedCooperative + if is_blockwise(gemm_kind): + schedules.append([ + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, + epilogue_schedule + ]) + else: + schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperative, + epilogue_schedule + ]) + stream_k_schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperative, + epilogue_schedule + ]) + return schedules, stream_k_schedules + + if not is_aligned and not is_blockwise(gemm_kind): + schedules = [[KernelScheduleType.CpAsyncWarpSpecialized, + default_epilogue]] + stream_k_schedules = [] + + if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative: + schedules.append([ + KernelScheduleType.CpAsyncWarpSpecializedCooperative, + default_epilogue + ]) + stream_k_schedules.append([ + KernelScheduleType.CpAsyncWarpSpecializedCooperative, + default_epilogue + ]) + + return schedules, stream_k_schedules + + schedules = [] + # Pruning: emit Void-C and Grouped kernels with persistent kernels only + if (level >= 1 or not is_void_c) and not grouped and not is_blockwise(gemm_kind): + # Pruning: don't stamp out fp8 kernels with auto schedule + if not is_fp8: + schedules.append([KernelScheduleType.ScheduleAuto, auto_epilogue]) + schedules.append([KernelScheduleType.TmaWarpSpecialized, default_epilogue]) + stream_k_schedules = [] + + if CudaToolkitVersionSatisfies(cuda_version, 12, 0): + if can_do_tma_epilogue: + assert not requires_transposed_epilogue + # Inconsistency: fp8 pingpong only gets stamped out with fast accum + if (not is_fp8 or level >= 1) and not is_blockwise(gemm_kind): + schedules.append([ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped) + ]) + if can_do_fp8_fast_accum: + schedules.append([ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, 
grouped) + ]) + + if CudaToolkitVersionSatisfies(cuda_version, 12, 1): + # Pruning: don't stamp out fp8 ping-pong kernel with non-tma epilogue + if not is_fp8 or level >= 1: + if not is_blockwise(gemm_kind): + schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)]) + else: + schedules.append([to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)]) + + if can_do_fp8_fast_accum: + if not grouped: + schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue]) + schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), to_grouped_schedule(default_epilogue, grouped)]) + + if can_do_cooperative: + if is_blockwise(gemm_kind): + schedules.append([ + to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(default_epilogue, grouped) + ]) + stream_k_schedules.append([ + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, + default_epilogue + ]) + else: + schedules.append([ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(default_epilogue, grouped) + ]) + stream_k_schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperative, + default_epilogue + ]) + if can_do_fp8_fast_accum: + schedules.append([ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped), + to_grouped_schedule(default_epilogue, grouped) + ]) + stream_k_schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, + default_epilogue + ]) + + # persistent kernels with TMA epilogues + if can_do_tma_epilogue: + assert not requires_transposed_epilogue + if can_do_cooperative: + if is_blockwise(gemm_kind): + schedules.append([ + 
to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + stream_k_schedules.append([ + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, + EpilogueScheduleType.TmaWarpSpecializedCooperative + ]) + else: + schedules.append([ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + stream_k_schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperative, + EpilogueScheduleType.TmaWarpSpecializedCooperative + ]) + if can_do_fp8_fast_accum: + schedules.append([ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + stream_k_schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, + EpilogueScheduleType.TmaWarpSpecializedCooperative + ]) + # Grouped GEMM do not support Stream-K scheduler + if grouped: + return schedules, [] + return schedules, stream_k_schedules + + +#### Misc: helpers + +def generate_data_types_from_math_instruction(math_instruction, element_source = None, element_dest = None, element_epilogue = None): + element_a, element_b = math_instruction.element_a, math_instruction.element_b + element_accumulator = math_instruction.element_accumulator + element_c = element_source or element_accumulator + element_d = element_dest or element_accumulator + element_epilogue = element_epilogue or element_accumulator + data_types = { + "a_type" : element_a, + "b_type" : element_b, + "c_type" : element_c, + "d_type" : element_d, + "acc_type" : element_accumulator, + "epi_type" : element_epilogue + } + return data_types + +def fix_alignments(data_types, layout, alignment_bits = 128): + operand_keys = ["a_type", "b_type", "c_type"] + operands_to_fix = ["c_type"] + 
new_layout = [] + assert len(layout) == len(operand_keys) + for i, k in enumerate(operand_keys): + assert k in data_types and data_types[k] in DataTypeSize + dtype = data_types[k] + dtype_size_bits = DataTypeSize[dtype] + + layout_type = layout[i][0] + layout_alignment = layout[i][1] + + # Don't modify alignment if dtype's been changed to void + if k in operands_to_fix and dtype_size_bits >= 1: + layout_alignment = alignment_bits // dtype_size_bits + + new_layout.append([layout_type, layout_alignment]) + + return new_layout diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/symm_operation.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/symm_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..8661ff798b2e3e0987fdf7e050b6ad2e0f8f3678 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/symm_operation.py @@ -0,0 +1,440 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for emitting Symm kernels +""" + +import enum +import functools +import operator +import os.path +import shutil + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + + +################################################################################################### +# +# Data structure modeling a Symm update operation +# +################################################################################################### + +# +class SymmOperation: + # + def __init__(self, symm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \ + blas_mode = BlasMode.symmetric): + + self.blas_mode = blas_mode + self.operation_kind = OperationKind.Symm + self.arch = arch + self.tile_description = tile_description + self.symm_kind = symm_kind + # tensor A and B have same 
data type and layout + self.A = A + self.B = B + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32 + ] + return self.tile_description.math_instruction.math_operation in complex_operators + return False + + # + def is_mixed_input(self): + return self.A.element != self.B.element + + # + def is_planar_complex(self): + return False + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. 
''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and' + } + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ + self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp: + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' + + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + inst_shape += math_op_string + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + operation_name = 'symm' if self.blas_mode == BlasMode.symmetric else 'hemm' + + return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name) + + # + def extended_name(self): + ''' Append data types if they differ from compute type. 
''' + if self.is_complex(): + extended_name = "${core_name}" + else: + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s" % ( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)] + ) + return "%s" % (ShortLayoutTypeNames[self.A.layout]) + + # + def side_mode_name(self): + return "%s" % (ShortSideModeNames[self.A.side_mode]) + + # + def fill_mode_name(self): + return "%s" % (ShortFillModeNames[self.A.fill_mode]) + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. 
''' + threadblock = self.tile_description.procedural_name() + + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + alignment = self.C.alignment + + return SubstituteTemplate( + "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${side_mode}_${fill_mode}_align${alignment}", + { + 'opcode_class': opcode_class_name, + 'extended_name': self.extended_name(), + 'threadblock': threadblock, + 'layout': self.layout_name(), + 'side_mode': self.side_mode_name(), + 'fill_mode': self.fill_mode_name(), + 'alignment': "%d" % alignment, + } + ) + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + return self.procedural_name() + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +# +class EmitSymmUniversalInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self): + self.symm_template = """ +// Symm operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::Symm< + ${element_a}, ${layout_a}, ${side_mode}, ${fill_mode}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + ${split_k_serial}, + ${math_operation} +>; +""" + 
self.symm_complex_template = """ +// Symm operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::Symm< + ${element_a}, ${layout_a}, ${side_mode}, ${fill_mode}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + ${split_k_serial}, + ${math_operation}, + ${blas_mode} +>; +""" + + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + + warp_count = operation.tile_description.warp_count + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'side_mode': SideModeTag[operation.A.side_mode], + 'fill_mode': FillModeTag[operation.A.fill_mode], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': 
str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'split_k_serial': 'false', + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'blas_mode': BlasModeTag[operation.blas_mode] + } + + symm_template = self.symm_complex_template if operation.is_complex() else self.symm_template + + return SubstituteTemplate(symm_template, values) + +################################################################################################### + + +################################################################################################### +# +# Emitters functions for all targets +# +################################################################################################### + +class EmitSymmConfigurationLibrary: + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') + + self.instance_emitter = { + SymmKind.Universal: EmitSymmUniversalInstance, + } + + self.symm_kind_wrappers = { + 
SymmKind.Universal: 'SymmOperation', + } + + self.instance_template = { + SymmKind.Universal: """ +${compile_guard_start} + manifest.append(new ${symm_kind}< + Operation_${operation_name} + >("${operation_name}")); +${compile_guard_end} +""" + } + + self.header_template = """ +/* + Generated by symm_operation.py - Do not edit. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "library_internal.h" +#include "symm_operation.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.initialize_function_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_${configuration_name}(Manifest &manifest) { + +""" + self.epilogue_template = """ + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def __enter__(self): + self.configuration_file = open(self.configuration_path, "w") + self.configuration_file.write(self.header_template) + + self.instance_definitions = [] + self.instance_wrappers = [] + + self.operations = [] + return self + + def emit(self, operation): + emitter = self.instance_emitter[operation.symm_kind]() + + self.operations.append(operation) + + self.instance_definitions.append(emitter.emit(operation)) + + self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.symm_kind], { + 'configuration_name': self.configuration_name, + 'operation_name': 
operation.procedural_name(), + 'symm_kind': self.symm_kind_wrappers[operation.symm_kind], + 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", + 'compile_guard_end': "#endif" \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" + })) + + def __exit__(self, exception_type, exception_value, traceback): + + # Write instance definitions in top-level namespace + for instance_definition in self.instance_definitions: + self.configuration_file.write(instance_definition) + + # Add wrapper objects within initialize() function + self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { + 'configuration_name': self.configuration_name + })) + + for instance_wrapper in self.instance_wrappers: + self.configuration_file.write(instance_wrapper) + + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + +################################################################################################### diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/trmm_operation.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/trmm_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..46ba360cb615c955d329b390c0ab93d13ed88c7c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/cutlass_library/trmm_operation.py @@ -0,0 +1,447 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Utilities for emitting Trmm kernels +""" + +import enum +import functools +import operator +import os.path +import shutil + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + + +################################################################################################### +# +# Data structure modeling a TRMM operation +# +################################################################################################### + +# +class TrmmOperation: + # + def __init__(self, trmm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8): + + self.operation_kind = OperationKind.Trmm + self.arch = arch + self.tile_description = tile_description + self.trmm_kind = trmm_kind + self.A = A + self.B = B + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32 + ] + return self.tile_description.math_instruction.math_operation in complex_operators + return False + + # + def is_planar_complex(self): +# return self.trmm_kind in (TrmmKind.PlanarComplex, TrmmKind.PlanarComplexArray) + return False + + # + def is_mixed_input(self): + return self.A.element != self.B.element + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def short_math_name(self): + 
if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and' + } + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ + self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp: + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' + + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + inst_shape += math_op_string + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, TrmmKindNames[self.trmm_kind]) + + # + def extended_name(self): + ''' Append data types if they differ from compute type. 
''' + if self.is_complex(): + extended_name = "${core_name}" + else: + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s%s" % ( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)] + ) + return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) + + # + def side_mode_name(self): + return "%s" % (ShortSideModeNames[self.A.side_mode]) + + # + def fill_mode_name(self): + return "%s" % (ShortFillModeNames[self.A.fill_mode]) + + # + def diag_type_name(self): + return "%s" % (ShortDiagTypeNames[self.A.diag_type]) + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. 
''' + threadblock = self.tile_description.procedural_name() + + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + alignment = max([self.C.alignment]) + + return SubstituteTemplate( + "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${side_mode}_${fill_mode}_${diag_type}_align${alignment}", + { + 'opcode_class': opcode_class_name, + 'extended_name': self.extended_name(), + 'threadblock': threadblock, + 'layout': self.layout_name(), + 'side_mode': self.side_mode_name(), + 'fill_mode': self.fill_mode_name(), + 'diag_type': self.diag_type_name(), + 'alignment': "%d" % self.C.alignment, + } + ) + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + return self.procedural_name() + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +# +class EmitTrmmUniversalInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self): + self.trmm_template = """ +// Trmm operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::Trmm< + ${element_a}, ${layout_a}, + ${side_mode}, ${fill_mode}, ${diag_type}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue}, + 
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + ${split_k_serial}, + ${math_operation} +>; +""" + self.trmm_complex_template = """ +// Trmm operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::Trmm< + ${element_a}, ${layout_a}, + ${side_mode}, ${fill_mode}, ${diag_type}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue}, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + ${split_k_serial}, + ${math_operation}, + ${transform_a} +>; +""" + + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + warp_count = operation.tile_description.warp_count + + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'side_mode' : SideModeTag[operation.A.side_mode], + 'fill_mode': FillModeTag[operation.A.fill_mode], + 'diag_type' : DiagTypeTag[operation.A.diag_type], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': 
DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(1), # TRMM A's alignment is always 1 for no padding to work until we make zfill work with variable bytes + 'align_b': str(operation.B.alignment), + 'split_k_serial': 'false', + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'transform_a': ComplexTransformTag[operation.A.complex_transform] + } + + trmm_template = self.trmm_complex_template if operation.is_complex() else self.trmm_template + + return SubstituteTemplate(trmm_template, values) + +################################################################################################### + + +################################################################################################### +# +# Emitters functions for all targets +# 
+################################################################################################### + +class EmitTrmmConfigurationLibrary: + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') + + self.instance_emitter = { + TrmmKind.Universal: EmitTrmmUniversalInstance, + } + + self.trmm_kind_wrappers = { + TrmmKind.Universal: 'TrmmOperation', + } + + self.instance_template = { + TrmmKind.Universal: """ +${compile_guard_start} + manifest.append(new ${trmm_kind}< + Operation_${operation_name} + >("${operation_name}")); +${compile_guard_end} +""" + } + + self.header_template = """ +/* + Generated by trmm_operation.py - Do not edit. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "library_internal.h" +#include "trmm_operation.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.initialize_function_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_${configuration_name}(Manifest &manifest) { + +""" + self.epilogue_template = """ + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def __enter__(self): + self.configuration_file = open(self.configuration_path, "w") + self.configuration_file.write(self.header_template) + + 
self.instance_definitions = [] + self.instance_wrappers = [] + + self.operations = [] + return self + + def emit(self, operation): + emitter = self.instance_emitter[operation.trmm_kind]() + + self.operations.append(operation) + + self.instance_definitions.append(emitter.emit(operation)) + + self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.trmm_kind], { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'trmm_kind': self.trmm_kind_wrappers[operation.trmm_kind], + 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", + 'compile_guard_end': "#endif" \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" + })) + + def __exit__(self, exception_type, exception_value, traceback): + + # Write instance definitions in top-level namespace + for instance_definition in self.instance_definitions: + self.configuration_file.write(instance_definition) + + # Add wrapper objects within initialize() function + self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { + 'configuration_name': self.configuration_name + })) + + for instance_wrapper in self.instance_wrappers: + self.configuration_file.write(instance_wrapper) + + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + +################################################################################################### diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/docs_src/source/conf.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/docs_src/source/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..c396d75a5534493f1ebf90043f2a182eb46abb7f --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/docs_src/source/conf.py @@ -0,0 +1,132 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys + +sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath('../../media/docs')) + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = 'CUTLASS Python interface' +copyright = '2023, NVIDIA' +author = 'NVIDIA' +release = '3.1.0' + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. 
+extensions = [ + 'myst_parser', + 'nbsphinx', + 'nbsphinx_link', + 'sphinx_copybutton', + 'sphinx.ext.autodoc', + 'sphinx.ext.autosectionlabel', + 'sphinx.ext.autosummary', + 'sphinx.ext.coverage', + 'sphinx.ext.extlinks', + 'sphinx.ext.ifconfig', + 'sphinx.ext.intersphinx', + 'sphinx.ext.mathjax', + 'sphinx.ext.napoleon', + 'sphinx.ext.viewcode', + 'sphinx_inline_tabs', + ] + +source_suffix = { + '.rst': 'restructuredtext', + '.md': 'markdown', +} + +autodoc_typehints = 'description' + +pygments_style = "sphinx" +pygments_dark_style = "monokai" + +templates_path = ['_templates'] +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + +# Ignore errors when converting notebooks +nbsphinx_allow_errors = True + +language = 'en' +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_static_path = ['_static'] + +html_title = "CUTLASS Python" +html_baseurl = 'docs' +html_theme = 'furo' +html_theme_options = { + "light_logo": "cutlass-logo-small.png", + "dark_logo": "cutlass-logo-small.png", + "light_css_variables": { + "color-brand-primary": "#76B900", + "color-brand-content": "#76B900", + }, + "dark_css_variables": { + "color-brand-primary": "#76B900", + "color-brand-content": "#76B900", + }, + "footer_icons": [ + { + "name": "GitHub", + "url": "https://github.com/NVIDIA/cutlass", + "html": """ + + + + """, + "class": "", + }, + ], +} diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/pycute/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/pycute/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..308a5676b06f00089d1cdfe0fb83b442ca2df36e --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/pycute/__init__.py @@ -0,0 +1,36 @@ 
+################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +from .int_tuple import * +from .layout import * +from .swizzle import * +from .typing import * diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/pycute/int_tuple.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/pycute/int_tuple.py new file mode 100644 index 0000000000000000000000000000000000000000..3d722130c52142e68a3bcd54ac708012aeeeaad3 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/pycute/int_tuple.py @@ -0,0 +1,225 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Functions for manipulating IntTuples +""" + +from functools import reduce +from itertools import chain +from typing import Union +from .typing import Integer + + +def is_int(x): + return isinstance(x, Integer) + + +def is_tuple(x): + return isinstance(x, tuple) + + +def flatten(t): + if is_tuple(t): + if len(t) == 0: + return () + else: + return tuple(i for a in t for i in flatten(a)) + else: + return (t,) + + +def signum(a): + return bool(a > 0) - bool(a < 0) + + +def product(a): + if is_tuple(a): + return reduce(lambda val,elem : val*product(elem), a, 1) + else: + return a + + +def inner_product(a, b): + if is_tuple(a): # tuple tuple + assert len(a) == len(b) + return sum(inner_product(x,y) for x,y in zip(a,b)) + else: # "int" "int" + assert not is_tuple(b) + return a * b + + +def tuple_max(a): + if is_tuple(a): + return max(tuple_max(x) for x in a) + else: + return a + + +def elem_scale(a, b): + if is_tuple(a): + if is_tuple(b): # tuple tuple + assert len(a) == len(b) + return tuple(elem_scale(x,y) for x,y in zip(a,b)) + else: # tuple "int" + assert False # Error + else: + if is_tuple(b): # "int" tuple + return elem_scale(a, product(b)) + else: # "int" "int" + return a * b + + +# Inclusive prefix ceil div with output congruent to input a +def shape_div(a, b): + if is_tuple(a): + if is_tuple(b): # tuple tuple + assert len(a) == 
len(b) + return tuple(shape_div(x,y) for x,y in zip(a,b)) + else: # tuple "int" + #r = [shape_div(a[0],b)] + [shape_div(a[i],b := shape_div(b, product(a[i-1]))) for i in range(1,len(a))] + r = [] + for v in a: + r.append(shape_div(v,b)) + b = shape_div(b,product(v)) + return tuple(r) + else: + if is_tuple(b): # "int" tuple + return shape_div(a, product(b)) + else: # "int" "int" + assert a % b == 0 or b % a == 0 + return (a + b - 1) // b + +# Exclusive prefix product with output congruent to input a +def prefix_product(a, init=1): + if is_tuple(a): + if is_tuple(init): # tuple tuple + assert len(a) == len(init) + return tuple(prefix_product(x,i) for x,i in zip(a,init)) + else: # tuple "int" + #r = [prefix_product(a[0],init)] + [prefix_product(a[i],init := init * product(a[i-1])) for i in range(1,len(a))] + r = [] + for v in a: + r.append(prefix_product(v,init)) + init = init * product(v) + return tuple(r) + else: + if is_tuple(init): # "int" tuple + assert False # Error + else: # "int" "int" + return init + + +def idx2crd(idx, shape, stride=None): + if stride is None: + stride = prefix_product(shape) + + if is_tuple(idx): + if is_tuple(shape): # tuple tuple tuple + assert len(idx) == len(shape) and len(idx) == len(stride) + return tuple(idx2crd(i, s, d) for i, s, d in zip(idx,shape,stride)) + else: # tuple "int" "int" + assert False # Error + else: + if is_tuple(shape): # "int" tuple tuple + assert len(shape) == len(stride) + return tuple(idx2crd(idx, s, d) for s,d in zip(shape,stride)) + else: # "int" "int" "int" + return (idx // stride) % shape + + +def crd2idx(crd, shape, stride=None): + if stride is None: + stride = prefix_product(shape) + + if is_tuple(crd): + if is_tuple(shape): # tuple tuple tuple + assert len(crd) == len(shape) and len(crd) == len(stride) + return sum(crd2idx(c, s, d) for c, s, d in zip(crd, shape, stride)) + else: # tuple "int" "int" + assert False, f"crd={crd}, shape={shape}" # Error + else: + if crd is None: + crd = 0 + + if 
is_tuple(shape): # "int" tuple tuple + assert len(shape) == len(stride) + result = 0 + for i in range(len(shape)-1): + result += crd2idx(crd % product(shape[i]), shape[i], stride[i]) + crd = crd // product(shape[i]) + return result + crd2idx(crd, shape[-1], stride[-1]) + else: # "int" "int" "int" + return crd * stride + + +# Transform crd into the dst_shape's iteration space +def crd2crd(crd, dst_shape, src_shape=None): + if is_tuple(crd): + if is_tuple(dst_shape): # tuple tuple + assert len(crd) == len(dst_shape) + return tuple(crd2crd(x, y) for x, y in zip(crd,dst_shape)) + else: # tuple "int" + # Ambiguous unless we have src_shape + assert src_shape is not None + return crd2idx(crd, src_shape) + else: + if is_tuple(dst_shape): # "int" tuple + return idx2crd(crd, dst_shape) + else: # "int" "int" + assert crd < dst_shape + return crd + + +# Filter trg according to crd: keep only elements of trg that are paired with None +def slice_(crd: Union[None, tuple, int], + trg: Union[tuple, int]): + if is_tuple(crd): + if is_tuple(trg): # tuple tuple + assert len(crd) == len(trg) + # match C++ behavior of `filter_tuple` using `tuple_cat(...)` + return tuple(chain(*filter(lambda x: x != (), [slice_(c, s) for c, s in zip(crd, trg)]))) + else: + assert False # tuple "int" : Error + elif crd is None: + # match C++ behavior `return cute::tuple{b};` + return (trg,) + else: + return () + + +# Determine if None appears at any of an int_tuples' terminals +def has_none(a: Union[None, tuple, int]): + if is_tuple(a): + return any(has_none(v) for v in a) + else: + return a is None diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/pycute/layout.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/pycute/layout.py new file mode 100644 index 0000000000000000000000000000000000000000..7c220eb16dd089c65fdbe6d6929b357ace0a77c1 --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/pycute/layout.py @@ -0,0 +1,367 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Definition of CuTe Layouts and functions to manipulate them +""" + +from itertools import chain +from typing import Union + +from .int_tuple import * + + +class LayoutBase: + pass + + +def is_layout(x): + return isinstance(x, LayoutBase) + + +class Layout(LayoutBase): + def __init__(self, _shape, _stride=None): + self.shape = _shape + if _stride is None: + self.stride = prefix_product(self.shape) + else: + self.stride = _stride + + # operator == + def __eq__(self, other): + return self.shape == other.shape and self.stride == other.stride + + # operator len(L) (len [rank] like tuples) + def __len__(self): + if is_tuple(self.shape): + return len(self.shape) + else: + return 1 + + # operator () (map coord to idx) + def __call__(self, *args): + """ + Map a logical coordinate to a linear index (Coord has no Underscore slice operators) + OR + Slice the layout and return the sublayout (Coord has an Underscore slice op) + + Follow the same behavior of `Layout::operator(Coord const&)` in cute C++ + """ + if has_none(args): + if len(args) == 1: + return Layout(slice_(args[0], self.shape), slice_(args[0], self.stride)) + else: + return Layout(slice_(args, self.shape), slice_(args, self.stride)) + else: + if len(args) == 1: + return crd2idx(args[0], self.shape, self.stride) + else: + return crd2idx(args, self.shape, self.stride) + + # operator [] (get-i like tuples) + def __getitem__(self, i): + if is_tuple(self.shape): + return Layout(self.shape[i], self.stride[i]) + else: + assert i == 0 + return Layout(self.shape, self.stride) + + # size(layout) Size of the domain + def size(self): + return product(self.shape) + + # cosize(layout) Size of the codomain + def cosize(self): + return self(self.size() - 1) + 1 + + # print and str + def __str__(self): + return f"{self.shape}:{self.stride}" + + # error msgs and representation + def __repr__(self): + return 
f"Layout({self.shape},{self.stride})" + + +# Make Layout from a list of layouts (each layout it's own mode in the result) +def make_layout(*layouts): + if len(layouts) == 1 and not is_layout(layouts[0]): + layouts = layouts[0] + + shape, stride = zip(*((a.shape,a.stride) for a in layouts)) + return Layout(shape, stride) + + +# Size of the domain +def size(layout): + if is_layout(layout): + return layout.size() + return product(layout) + + +# Size of the codomain +def cosize(layout): + return layout.cosize() + + +# Layout coalesce -- flatten and combine as many modes as possible while preserving the int-to-int function +def coalesce(layout, profile=None): + if is_tuple(profile): + assert len(layout) >= len(profile) + return make_layout(chain((coalesce(layout[i], profile[i]) for i in range( 0,len(profile))), + (layout[i] for i in range(len(profile),len(layout))))) + + result_shape = [1] + result_stride = [0] + for (shape,stride) in zip(flatten(layout.shape),flatten(layout.stride)): + # skip their shape-1s + if shape == 1: + continue + # replace our shape-1 with anything + elif result_shape[-1] == 1: + result_shape[-1] = shape + result_stride[-1] = stride + # merge modes if the shape*stride match + elif result_shape[-1] * result_stride[-1] == stride: + result_shape[-1] = result_shape[-1] * shape + # append a new mode + else: + result_shape.append(shape) + result_stride.append(stride) + + if len(result_shape) == 1: + return Layout(result_shape[0], result_stride[0]) + else: + return Layout(tuple(result_shape), tuple(result_stride)) + + +# Layout filter -- replace all stride-0 modes with size-1 and then coalesce to remove them +def filter(layout, profile=None): + if is_tuple(profile): + assert len(layout) >= len(profile) + return make_layout(chain((filter(layout[i], profile[i]) for i in range( 0,len(profile))), + (layout[i] for i in range(len(profile),len(layout))))) + + result_shape = [] + result_stride = [] + for (shape,stride) in 
zip(flatten(layout.shape),flatten(layout.stride)): + # skip their shape-1s and stride-0s + if not (shape == 1 or stride == 0): + result_shape.append(shape) + result_stride.append(stride) + + if len(result_shape) == 0: + return Layout(1,0) + else: + return coalesce(Layout(tuple(result_shape), tuple(result_stride))) + + +# Layout composition +# Use tuples-of-layouts to perform this operation by-mode and None as no-op +def composition(layoutA, layoutB): + if layoutB is None: + return layoutA + elif is_int(layoutB): + return composition(layoutA, Layout(layoutB)) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + return make_layout(chain((composition(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))), + (layoutA[i] for i in range(len(layoutB),len(layoutA))))) + elif is_tuple(layoutB.shape): + return make_layout(composition(layoutA, layoutB_i) for layoutB_i in layoutB) + + if layoutB.stride == 0: + return Layout(layoutB.shape, 0) + else: + result_shape = [] + result_stride = [] + rest_shape = layoutB.shape + rest_stride = layoutB.stride + flat_A = coalesce(layoutA) + for (curr_shape, curr_stride) in zip(flatten(flat_A.shape)[:-1], flatten(flat_A.stride)[:-1]): + assert curr_shape % rest_stride == 0 or rest_stride % curr_shape == 0 + new_shape = min(max(1, curr_shape // rest_stride), rest_shape) + + if new_shape != 1: + result_shape.append(new_shape) + result_stride.append(rest_stride * curr_stride) + + rest_shape = rest_shape // new_shape + rest_stride = -(-rest_stride // curr_shape) # Python exclusive impl: "//" is always floor div so == ceil_div(abs(rest_stride), curr_shape) * signum(rest_stride) + + if rest_shape != 1 or len(result_shape) == 0: + result_shape.append(rest_shape) + result_stride.append(rest_stride * flatten(flat_A.stride)[-1]) + + if len(result_shape) == 1: + return Layout(result_shape[0], result_stride[0]) + else: + return Layout(tuple(result_shape), tuple(result_stride)) + + +# Layout complement +def complement(layout, max_idx=1): 
+ if is_int(layout): + return complement(Layout(layout)) + + result_shape = [] + result_stride = [] + current_idx = 1 + + sorted_DS = sorted(zip(flatten(layout.stride), flatten(layout.shape))) + for (stride, shape) in sorted_DS: + if stride == 0 or shape == 1: + continue + + in_bound = current_idx <= shape * stride + # To support symbolic value which can't be evaluated now + assert (type(in_bound) is not bool) or in_bound + + result_shape.append(stride // current_idx) + result_stride.append(current_idx) + current_idx = shape * stride + + result_shape.append((max_idx + current_idx - 1) // current_idx) # ceil_div + result_stride.append(current_idx) + + return coalesce(Layout(tuple(result_shape), tuple(result_stride))) + + +# Layout right inverse +def right_inverse(layout): + if layout is None: + return None + elif is_int(layout): + return Layout(layout) + + result_shape = [] + result_stride = [] + current_idx = 1 + + flat_shape = flatten(layout.shape) + flat_stride = flatten(layout.stride) + sorted_DSA = sorted(zip(flat_stride, flat_shape, prefix_product(flat_shape))) + for (stride,shape,rstride) in sorted_DSA: + if shape == 1: + continue + if current_idx != stride: + break + + result_shape.append(shape) + result_stride.append(rstride) + current_idx = shape * stride + + return coalesce(Layout(tuple(result_shape), tuple(result_stride))) + + +# Layout left inverse +def left_inverse(layout): + if layout is None: + return None + elif is_int(layout): + return Layout(layout) + return right_inverse(make_layout(layout, complement(layout))) + + +# Split a layout by the composition of B and the "rest" +# Use tuples-of-layouts to perform this operation by-mode and None as no-op +def logical_divide(layoutA, layoutB): + if layoutB is None: + return layoutA + elif is_int(layoutB): + return logical_divide(layoutA, Layout(layoutB)) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + return make_layout(chain((logical_divide(layoutA[i], layoutB[i]) for i in range( 
0,len(layoutB))), + (layoutA[i] for i in range(len(layoutB),len(layoutA))))) + + return composition(layoutA, make_layout(layoutB, complement(layoutB, size(layoutA)))) + + +# Reproduce a layoutA over a layoutB +# Use tuples-of-layouts to perform this operation by-mode and None as no-op +def logical_product(layoutA, layoutB): + if layoutB is None: + return layoutA + elif is_int(layoutB): + return logical_divide(layoutA, Layout(layoutB)) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + return make_layout(chain((logical_product(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))), + (layoutA[i] for i in range(len(layoutB),len(layoutA))))) + + return make_layout(layoutA, composition(complement(layoutA, size(layoutA)*cosize(layoutB)), layoutB)); + + +# Gather the modes from a hierarchical logical_divide or logical_product +def hier_unzip(splitter, layoutA, layoutB): + if layoutB is None: + return make_layout(Layout(1,0), layoutA) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + # A layout with shape ((A,a),(B,b),(C,c)) + split = make_layout(hier_unzip(splitter, layoutA[i], layoutB[i]) for i in range(0,len(layoutB))) + # Gather to shape ((A,B,C,...),(a,b,c,...,y,z)) + return make_layout(make_layout( split[i][0] for i in range( 0,len(layoutB))), + make_layout(chain((split[i][1] for i in range( 0,len(layoutB))), + (layoutA[i] for i in range(len(layoutB),len(layoutA)))))) + + # splitter must return a rank-2 layout + return splitter(layoutA, layoutB) + + +# Apply logical divide hierarchically and gather the split modes into two modes +def zipped_divide(layoutA, layoutB): + return hier_unzip(logical_divide, layoutA, layoutB) + + +# Perform logical divide hierarchically and gather tiles (B-layouts) into a new mode +def tiled_divide(layoutA, layoutB): + result = zipped_divide(layoutA, layoutB) + return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))]) + + +# Apply logical product hierarchically and gather the split 
modes into two modes +def zipped_product(layoutA, layoutB): + return hier_unzip(logical_product, layoutA, layoutB) + + +# Perform logical product hierarchically and gather tiles (B-layouts) into a new mode +def tiled_product(layoutA, layoutB): + result = zipped_product(layoutA, layoutB) + return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))]) + + +def slice_and_offset(crd: tuple, + layout: Layout): + return (Layout(slice_(crd, layout.shape), slice_(crd, layout.stride)), + crd2idx(crd, layout.shape, layout.stride)) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/pycute/swizzle.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/pycute/swizzle.py new file mode 100644 index 0000000000000000000000000000000000000000..308aee0c3838a82c4de53833fb8a36950b30f62d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/pycute/swizzle.py @@ -0,0 +1,129 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Methods for layout swizzling +""" + +from .layout import * + + +def shiftr(a, s): + return a >> s if s > 0 else shiftl(a, -s) + + +def shiftl(a, s): + return a << s if s > 0 else shiftr(a, -s) + + +## A generic Swizzle functor + # 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx + # ^--^ Base is the number of least-sig bits to keep constant + # ^-^ ^-^ Bits is the number of bits in the mask + # ^---------^ Shift is the distance to shift the YYY mask + # (pos shifts YYY to the right, neg shifts YYY to the left) + # + # e.g. 
Given + # 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx + # the result is + # 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY + # +class Swizzle: + def __init__(self, bits, base, shift): + assert bits >= 0 + assert base >= 0 + assert abs(shift) >= bits + self.bits = bits + self.base = base + self.shift = shift + bit_msk = (1 << bits) - 1 + self.yyy_msk = bit_msk << (base + max(0,shift)) + self.zzz_msk = bit_msk << (base - min(0,shift)) + + # operator () (transform integer) + def __call__(self, offset): + return offset ^ shiftr(offset & self.yyy_msk, self.shift) + + # Size of the domain + def size(self): + return 1 << (self.bits + self.base + abs(self.shift)) + + # Size of the codomain + def cosize(self): + return self.size() + + # print and str + def __str__(self): + return f"SW_{self.bits}_{self.base}_{self.shift}" + + # error msgs and representation + def __repr__(self): + return f"Swizzle({self.bits},{self.base},{self.shift})" + + +class ComposedLayout(LayoutBase): + def __init__(self, layoutB, offset, layoutA): + self.layoutB = layoutB + self.offset = offset + self.layoutA = layoutA + + # operator == + def __eq__(self, other): + return self.layoutB == other.layoutB and self.offset == other.offset and self.layoutA == other.layoutA + + # operator len(L) (len [rank] like tuples) + def __len__(self): + return len(self.layoutA) + + # operator () (map coord to idx) + def __call__(self, *args): + return self.layoutB(self.offset + self.layoutA(*args)) + + # operator [] (get-i like tuples) + def __getitem__(self, i): + return ComposedLayout(self.layoutB, self.offset, self.layoutA[i]) + + # size(layout) Size of the domain + def size(self): + return size(self.layoutA) + + # cosize(layout) Size of the codomain + def cosize(self): + return cosize(self.layoutB) + + # print and str + def __str__(self): + return f"{self.layoutB} o {self.offset} o {self.layoutA}" + + # error msgs and representation + def __repr__(self): + return 
f"ComposedLayout({repr(self.layoutB)},{repr(self.offset)},{repr(self.layoutA)})" diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/pycute/typing.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/pycute/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..834f7e5411f5c2a4e218f9ce8a4f0a229d039710 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/pycute/typing.py @@ -0,0 +1,42 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from abc import ABC + + +class Integer(ABC): + @classmethod + def __subclasshook__(cls, c): + if c in [bool, float]: + return False + + return issubclass(c, int) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/setup_cutlass.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/setup_cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..acc0c46e540735443a4943908852010a80d02187 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/setup_cutlass.py @@ -0,0 +1,74 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + + +import copy +import os +import setuptools +from setuptools import setup +from setuptools.command.build_ext import build_ext + +import setup_pycute +import setup_library + + +# Install cutlass_library package +setup_library.perform_setup() + + +# Install the PyCuTe package +setup_pycute.perform_setup() + + +setup( + name='cutlass_cppgen', + version='4.2.0', + description='CUTLASS Pythonic Interface', + package_dir={'': '.'}, + packages=[ + 'cutlass_cppgen', + 'cutlass_cppgen.emit', + 'cutlass_cppgen.op', + 'cutlass_cppgen.utils', + 'cutlass_cppgen.backend', + 'cutlass_cppgen.backend.utils' + ], + setup_requires=['pybind11'], + install_requires=[ + 'bfloat16', + 'cuda-python>=11.8.0', + 'pybind11', + 'scikit-build', + 'treelib', + 'pydot' + ] +) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/setup_library.py 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/setup_library.py new file mode 100644 index 0000000000000000000000000000000000000000..c56d6b5556fea2d5e56209b13f5b95e487ca22fb --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/setup_library.py @@ -0,0 +1,46 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from setuptools import setup + + +def perform_setup(): + setup( + name='cutlass_library', + version='4.2.1', + description='CUTLASS library generation scripts', + packages=['cutlass_library'] + ) + + +if __name__ == '__main__': + perform_setup() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/setup_pycute.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/setup_pycute.py new file mode 100644 index 0000000000000000000000000000000000000000..0bad050fcade8b26d33043abbb0f8226be7d816c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/python/setup_pycute.py @@ -0,0 +1,46 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
from setuptools import setup


def perform_setup():
    """Invoke setuptools for the ``pycute`` distribution."""
    # Metadata is collected first so the setup() call itself stays a one-liner.
    metadata = {
        'name': 'pycute',
        'version': '4.2.1',
        'description': 'Python implementation of CuTe',
        'packages': ['pycute'],
    }
    setup(**metadata)


if __name__ == '__main__':
    perform_setup()
"""
Utilities for defining Conv2D problem sizes for testing.

This file was ported from the C++ version in test/unit/conv/device/conv2d_problems.h
"""

from cutlass_library import ConvMode

import cutlass_cppgen
from cutlass_cppgen.shape import Conv2DProblemSize


class TestbedConv2dProblemSizes:
    """
    Collection of Conv2D problem sizes used by the Conv2d unit tests.

    After construction, ``self.all`` holds every generated problem size whose
    per-group channel count (``C // groups``) is a multiple of
    ``minimum_channel_size``.

    NOTE(review): the positional layout of ``Conv2DProblemSize`` is defined in
    ``cutlass_cppgen.shape`` and is not visible from this file. The tables below
    presumably follow the C++ constructor order
    ``(N, H, W, C,  K, R, S, C,  pad_h, pad_w,  stride_h, stride_w,
    dilation_h, dilation_w[, conv_mode, split_k_slices, groups])`` -- confirm
    against the class definition before relying on it.
    """

    def __init__(self, minimum_channel_size: int):
        """
        Build all problem-size lists and keep the ones compatible with
        ``minimum_channel_size``.

        :param minimum_channel_size: channel granularity; also used as the
            channel count of the small default problems
        :type minimum_channel_size: int
        """
        # Generation order is preserved: default, rigorous, ResNet50 (batch 1),
        # ResNet50 (batch 34), grouped.
        candidate_lists = [
            self.initialize_conv2d_default_sizes(minimum_channel_size),
            self.initialize_conv2d_rigorous_sizes(minimum_channel_size),
            self.initialize_conv2d_resnet50_sizes(1),
            self.initialize_conv2d_resnet50_sizes(34),
            self.initialize_conv2d_grouped_sizes(),
        ]

        # Filter all problems
        self.all = [
            size
            for size_list in candidate_lists
            for size in size_list
            if (size.C // size.groups) % minimum_channel_size == 0
        ]

    def initialize_conv2d_default_sizes(self, minimum_channel_size):
        """
        Return the default list of small and medium problem sizes.

        :param minimum_channel_size: channel count used for the small problems
        :return: list of ``Conv2DProblemSize``
        """
        m = minimum_channel_size
        rows = [
            # Small input size x stride (1,1)
            # C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64}
            (1, 1, 1, m, 8, 1, 1, m, 1, 1, 1, 1, 1, 1),
            (1, 1, 8, m, 8, 1, 3, m, 1, 1, 1, 1, 1, 1),
            (1, 7, 8, m, 8, 3, 3, m, 1, 1, 1, 1, 1, 1),
            (1, 7, 9, m, 8, 4, 4, m, 1, 1, 1, 1, 1, 1),
            (2, 7, 9, m, 8, 5, 5, m, 1, 1, 1, 1, 1, 1),
            (3, 7, 9, m, 8, 6, 5, m, 1, 1, 1, 1, 1, 1),
            (3, 7, 9, m, 8, 6, 6, m, 1, 1, 1, 1, 1, 1),
            (3, 7, 9, m, 8, 7, 7, m, 1, 1, 1, 1, 1, 1),

            # Small input size x stride (2,2)
            # C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64}
            (1, 11, 7, m, 8, 1, 1, m, 0, 0, 2, 2, 1, 1),
            (1, 11, 7, m, 8, 3, 3, m, 1, 1, 2, 2, 1, 1),
            (1, 13, 11, m, 8, 1, 1, m, 1, 1, 2, 2, 1, 1),
            (1, 17, 19, m, 16, 2, 2, m, 1, 1, 2, 2, 1, 1),
            (1, 23, 5, m, 16, 3, 3, m, 1, 1, 2, 2, 1, 1),
            (1, 13, 17, 8, 24, 3, 3, 8, 0, 0, 2, 2, 1, 1),
            (1, 23, 21, 8, 24, 3, 3, 8, 1, 1, 3, 3, 1, 1),
            (1, 20, 24, 8, 40, 3, 3, 8, 3, 3, 3, 3, 1, 1),

            # Medium input size (1x16x16x128), filter size (1x1, 2x2, 3x3, 5x5), stride (1, 1)
            (1, 15, 19, 160, 224, 1, 1, 160, 0, 0, 1, 1, 1, 1),
            (1, 19, 37, 160, 224, 3, 3, 160, 1, 1, 2, 2, 1, 1),
            (1, 16, 16, 160, 224, 2, 3, 160, 1, 1, 1, 1, 1, 1),
            (1, 23, 21, 128, 224, 3, 3, 128, 1, 1, 1, 1, 1, 1),
            (1, 29, 37, 160, 224, 5, 5, 160, 2, 2, 1, 1, 1, 1),

            # C > CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64}
            (1, 15, 19, 32 + m, 96, 3, 3, 32 + m, 1, 1, 1, 1, 1, 1),
            (1, 16, 24, 64 + m, 96, 3, 3, 64 + m, 1, 1, 1, 1, 1, 1),

            # Medium input size, filter size (1x1, 3x3, 5x5, 7x7), stride (2, 2)
            (1, 13, 16, 288, 160, 5, 5, 288, 2, 2, 2, 2, 1, 1),
            (1, 55, 51, 256, 512, 1, 1, 256, 0, 0, 2, 2, 1, 1),
            (1, 71, 80, 32, 64, 5, 5, 32, 2, 2, 2, 2, 1, 1),
            (1, 224, 224, 8, 64, 7, 7, 8, 3, 3, 2, 2, 1, 1),

            # Medium input size stride (3, 3), filter (3, 3), non-default padding
            (1, 27, 23, 256, 512, 3, 3, 256, 0, 0, 3, 3, 1, 1),

            # Medium input size padding > stride, asymmetric filter, padding and striding
            (1, 27, 31, 256, 512, 3, 3, 256, 5, 7, 3, 4, 1, 1),
            (1, 27, 35, 256, 512, 7, 5, 256, 11, 7, 3, 5, 1, 1),

            # Medium input size *mixed* stride (1, 2) and (2, 1),
            # filter (3, 3), default padding
            (1, 27, 27, 256, 512, 3, 3, 256, 1, 1, 1, 2, 1, 1),
            (1, 27, 27, 256, 512, 3, 3, 256, 1, 1, 2, 1, 1, 1),

            # Additional input size
            (3, 28, 28, 256, 256, 2, 2, 256, 0, 0, 2, 2, 1, 1),
            (1, 32, 32, 16, 32, 3, 3, 16, 1, 1, 6, 2, 1, 1),
            (32, 24, 32, 32, 32, 1, 2, 32, 0, 0, 1, 1, 1, 1),
            (4, 2, 3, 256, 328, 3, 5, 256, 1, 1, 1, 1, 1, 1),
        ]
        return [Conv2DProblemSize(*row) for row in rows]

    def initialize_conv2d_rigorous_sizes(self, minimum_channel_size, enabled=False):
        """
        Return a few large and rigorous convolution problem sizes.

        These sizes were guarded by a hard-coded ``if False:`` and therefore
        never generated; the guard is now the ``enabled`` parameter, whose
        default keeps the previous behavior (an empty list).

        :param minimum_channel_size: base channel count of the problems
        :param enabled: when True, actually build the rigorous sizes
        :type enabled: bool
        :return: list of ``Conv2DProblemSize`` (empty unless ``enabled``)
        """
        if not enabled:
            return []

        m = minimum_channel_size
        return [
            Conv2DProblemSize.from_sizes(
                (1, 124, 224, 2 * m),
                (24, 7, 7, 2 * m),
            ),
            Conv2DProblemSize.from_sizes(
                (1, 233, 35, m),
                (24, 7, 5, m),
            ),
        ]

    def initialize_conv2d_resnet50_sizes(self, batch_size):
        """
        Return ResNet50 layer shapes for unit testing.

        :param batch_size: first constructor argument of every problem size
        :return: list of ``Conv2DProblemSize``
        """
        b = batch_size
        rows = [
            (b, 56, 56, 64, 256, 1, 1, 64, 0, 0, 1, 1, 1, 1),
            (b, 56, 56, 64, 64, 1, 1, 64, 0, 0, 1, 1, 1, 1),
            (b, 56, 56, 64, 64, 3, 3, 64, 1, 1, 1, 1, 1, 1),
            (b, 56, 56, 256, 64, 1, 1, 256, 0, 0, 1, 1, 1, 1),
            (b, 56, 56, 256, 512, 1, 1, 256, 0, 0, 2, 2, 1, 1),
            (b, 56, 56, 256, 128, 1, 1, 256, 0, 0, 2, 2, 1, 1),
            (b, 28, 28, 128, 128, 3, 3, 128, 1, 1, 1, 1, 1, 1),
            (b, 28, 28, 128, 512, 1, 1, 128, 0, 0, 1, 1, 1, 1),
            (b, 28, 28, 512, 128, 1, 1, 512, 0, 0, 1, 1, 1, 1),
            (b, 28, 28, 512, 1024, 1, 1, 512, 0, 0, 2, 2, 1, 1),
            (b, 28, 28, 512, 256, 1, 1, 512, 0, 0, 2, 2, 1, 1),
            (b, 14, 14, 256, 256, 3, 3, 256, 1, 1, 1, 1, 1, 1),
            (b, 14, 14, 256, 1024, 1, 1, 256, 0, 0, 1, 1, 1, 1),
            (b, 14, 14, 1024, 256, 1, 1, 1024, 0, 0, 1, 1, 1, 1),
            (b, 14, 14, 1024, 2048, 1, 1, 1024, 0, 0, 2, 2, 1, 1),
            (b, 14, 14, 1024, 512, 1, 1, 1024, 0, 0, 2, 2, 1, 1),
            (b, 7, 7, 512, 512, 3, 3, 512, 1, 1, 1, 1, 1, 1),
            (b, 7, 7, 512, 2048, 1, 1, 512, 0, 0, 1, 1, 1, 1),
            (b, 7, 7, 2048, 512, 1, 1, 2048, 0, 0, 1, 1, 1, 1),
        ]
        return [Conv2DProblemSize(*row) for row in rows]

    def initialize_conv2d_grouped_sizes(self):
        """
        Return problem sizes for grouped convolution.

        Sizes are derived from a fixed threadblock tile of N=128, K=32 and all
        use ``ConvMode.CrossCorrelation``.

        :return: list of ``Conv2DProblemSize``
        """
        threadblock_n = 128
        threadblock_k = 32
        mode = ConvMode.CrossCorrelation

        # One group calculated by one or multiple CTAs: k_per_group % CTA::N = 0
        # One CTA calculates a single group
        sizes = [
            Conv2DProblemSize(
                1, 8, 8, threadblock_k * 2 * groups,
                cta_per_group_k * threadblock_n * groups, 3, 3, threadblock_k * 2,
                1, 1,
                1, 1,
                1, 1,
                mode,
                1,
                groups
            )
            for cta_per_group_k in range(1, 4)
            for groups in range(2, 5)
        ]

        rows = [
            # Partial gemm_k: k_per_group == CTA::N && channels_per_group < CTA::K
            (1, 8, 8, threadblock_k,
             threadblock_n * 2, 3, 3, threadblock_k // 2,
             1, 1, 1, 1, 1, 1, mode, 1, 2),
            (1, 56, 56, 696,
             768, 3, 3, 232,
             1, 1, 2, 2, 1, 1, mode, 1, 3),
            (1, 14, 14, 1392,
             1536, 3, 3, 232,
             1, 1, 1, 1, 1, 1, mode, 1, 3),

            # One CTA calculate multiple groups: CTA::N % k_per_group = 0

            # 2 groups per CTA
            (1, 8, 8, threadblock_k * 4,
             threadblock_n, 3, 3, threadblock_k * 2,
             1, 1, 1, 1, 1, 1, mode, 1, 2),
            # 2 groups per CTA and partial gemm_k
            (1, 8, 8, threadblock_k,
             threadblock_n, 3, 3, threadblock_k // 2,
             1, 1, 1, 1, 1, 1, mode, 1, 2),
            # 4 groups per CTA
            (1, 8, 8, threadblock_k * 8,
             threadblock_n // 2, 3, 3, threadblock_k * 2,
             1, 1, 1, 1, 1, 1, mode, 1, 4),
            # 4 groups per CTA and partial gemm_k
            (1, 8, 8, threadblock_k * 2,
             threadblock_n // 2, 3, 3, threadblock_k // 2,
             1, 1, 1, 1, 1, 1, mode, 1, 4),
        ]
        sizes.extend(Conv2DProblemSize(*row) for row in rows)

        return sizes
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..f77a0ec831be087bd3badc929eee955f0b37c489 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_sm80.py @@ -0,0 +1,146 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
"""
Low-level functionality tests for Conv2d operations on SM80
"""

import logging
import unittest

import cutlass_cppgen
from cutlass_cppgen.backend.utils.device import device_cc

# Import only the helpers this script uses instead of a wildcard import, so the
# origin of each name is visible.
from conv2d_test_utils import (
    add_test,
    conv2d_few_channel_problemsizes,
    get_conv_problems,
)


cutlass_cppgen.set_log_level(logging.WARNING)
cc = 80


@unittest.skipIf(device_cc() < cc, 'Device compute capability is invalid for SM80 tests.')
class Conv2dSm80(unittest.TestCase):
    """
    Wrapper class to which tests are added dynamically by the module-level
    ``add_test`` calls below (they run when this module is imported)
    """
    pass


conv_problems = get_conv_problems()


# Tests for optimized & analytic
for conv_kind in ["fprop", "wgrad", "dgrad"]:
    # F16, simt
    add_test(
        Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
        opclass="simt", threadblock_shape=[128, 128, 8],
        warp_count=[4, 2, 1], stages=2, instruction_shape=[1, 1, 1])
    # F16, tensor op
    add_test(
        Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
        opclass="tensor_op", threadblock_shape=[128, 128, 64],
        warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16])
    # F16, tensor op, analytic iterator
    add_test(
        Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16,
        opclass="tensor_op", threadblock_shape=[128, 128, 64],
        warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], iterator_algorithm="analytic")
    # F16, tensor op, f32 output
    add_test(
        Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32,
        opclass="tensor_op", threadblock_shape=[128, 128, 64],
        warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16])
    # F16, tensor op, different tile description
    add_test(
        Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
        opclass="tensor_op", threadblock_shape=[128, 64, 32],
        warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8])
    # F32, simt
    add_test(
        Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32,
        opclass="simt", threadblock_shape=[128, 128, 8],
        warp_count=[4, 2, 1], stages=4, instruction_shape=[1, 1, 1])
    # Tf32, tensorop
    add_test(
        Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32,
        opclass="tensor_op", threadblock_shape=[128, 128, 16],
        warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8]
    )
    # Split-K
    add_test(
        Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
        opclass="tensor_op", threadblock_shape=[128, 128, 64],
        warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode="serial",
        split_k_slices=2)
    add_test(
        Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
        opclass="tensor_op", threadblock_shape=[128, 128, 64],
        warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode="parallel",
        split_k_slices=5)
    # Swizzling functor
    add_test(
        Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
        opclass="tensor_op", threadblock_shape=[128, 64, 32],
        warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8], swizzle=4)

# Tests for few channels and fixed channels
# F16, tensor op, few channels
for c, tb, stage, inst in zip([2, 1],
                              [[128, 128, 64], [128, 128, 32]],
                              [3, 2],
                              [[16, 8, 16], [16, 8, 8]]):
    add_test(
        Conv2dSm80, cc, "fprop", conv2d_few_channel_problemsizes(c), cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
        opclass="tensor_op", threadblock_shape=tb,
        warp_count=[2, 2, 1], stages=stage, instruction_shape=inst, iterator_algorithm="few_channels"
    )
# F16, tensor op, fixed channels
for c in [8, 4, 2]:
    add_test(
        Conv2dSm80, cc, "fprop", conv2d_few_channel_problemsizes(c), cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
        opclass="tensor_op", threadblock_shape=[128, 128, 64],
        warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], iterator_algorithm="fixed_channels"
    )

# Test activations
for activation in ["relu", "leaky_relu"]:
    for split_k_mode, split_k_slices in zip(["parallel", "serial", "parallel"], [1, 7, 5]):
        add_test(
            Conv2dSm80, cc, "fprop", conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
            opclass="tensor_op", threadblock_shape=[128, 128, 64],
            warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode=split_k_mode,
            split_k_slices=split_k_slices, activation=activation)


if __name__ == '__main__':
    unittest.main()
0000000000000000000000000000000000000000..9bc4542cd5ccf72341f7db3c7947d481b032926d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_test_utils.py @@ -0,0 +1,428 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Utility functions for Conv2d tests. +""" + +from cutlass_library import SubstituteTemplate +import torch + +import cutlass_cppgen +from cutlass_library import ( + ConvKind, + ConvMode, + DataType, + DataTypeNames, + EpilogueScheduleSuffixes, + KernelScheduleSuffixes, + LayoutType, + OpcodeClassNames, + ShortDataTypeNames, + ShortLayoutTypeNames, + SplitKMode, +) +from cutlass_cppgen.shape import Conv2DProblemSize +from cutlass_cppgen.utils.datatypes import numpy_type, torch_type + +from conv2d_problem_sizes import TestbedConv2dProblemSizes + + +def get_name_conv2d( + arch, + conv_kind, + element, + element_accumulator, + element_output, + opclass, + threadblock_shape, + warp_count, + instruction_shape, + stages, + iterator_algorithm, + swizzle, + split_k_mode, + split_k_slices, + activation +): + """ + Generates a procedural name for a test case for conv2d + + :param arch: compute capability of kernel being generated + :type arch: int + :param conv_kind: the convolution type (i.e. 
fprop, dgrad, wgrad) + :type conv_kind: str + :param iterator_algorithm: the iterator algorithm applied + :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm + :param element_a: data type of operand A + :param element_b: data type of operand B + :param element_c: data type of operand C + :param element_accumulator: data type used in accumulation + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass_cppgen.OpcodeClass + :param threadblock_shape: indexable container of dimensions of threadblock tiles + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param stride_support: stride support of dgrad + :param alignment: int + :type alignment: int + + :return: str + """ + if iterator_algorithm is None: + iterator_algorithm = "AUTO" + if swizzle is None: + swizzle = 1 + name_format = "test_SM${arch}_Device_Conv2d_${conv_kind}_${iter_alg}_ImplicitGemm_${eA}nhwc_${eB}nhwc_${eC}nhwc_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${wM}x${wN}x${wK}_${IM}${IN}${IK}_stage${stages}_swizzle${swizzle}_${split_k_mode}${split_k_slices}_${activation}" + + return SubstituteTemplate( + name_format, + { + "arch": str(arch), + "conv_kind": conv_kind, + "iter_alg": iterator_algorithm, + "eA": DataTypeNames[element], + "eB": DataTypeNames[element], + "eC": DataTypeNames[element_output], + "opclass": opclass, + "acc": DataTypeNames[element_accumulator], + "tbM": str(threadblock_shape[0]), + "tbN": str(threadblock_shape[1]), + "tbK": str(threadblock_shape[2]), + "wM": str(threadblock_shape[0] // warp_count[0]), + "wN": str(threadblock_shape[1] // warp_count[1]), + "wK": str(threadblock_shape[2] // warp_count[2]), + "IM": str(instruction_shape[0]), + "IN": str(instruction_shape[1]), + "IK": str(instruction_shape[2]), + "stages": str(stages), + "swizzle": str(swizzle), + "split_k_mode": split_k_mode, + "split_k_slices": str(split_k_slices), + "activation": activation + } + ) + + +def 
conv2d_few_channel_problemsizes(channels): + problem_sizes = [ + Conv2DProblemSize( + 1, 8, 8, channels, + 16, 3, 3, channels, + 1, 1, + 2, 2, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 16, 16, channels, + 16, 3, 3, channels, + 1, 1, + 2, 2, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 16, 16, channels, + 16, 7, 7, channels, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 224, 224, channels, + 32, 7, 7, channels, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 224, 224, channels, + 64, 7, 7, channels, + 1, 1, + 2, 2, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 224, 224, channels, + 64, 5, 5, channels, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 224, 224, channels, + 64, 5, 5, channels, + 1, 1, + 2, 2, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + ] + + return problem_sizes + + +def validate_problem_size(ps, conv_kind, split_k_slices): + P = (ps.H + 2 * ps.pad_h - ps.dilation_h * (ps.R - 1) - 1) // ps.stride_h + 1 + Q = (ps.W + 2 * ps.pad_w - ps.dilation_w * (ps.S - 1) - 1) // ps.stride_w + 1 + if P != ps.P or Q != ps.Q: + return False + + # Split-K (serial or parallel) is not supported for strided dgrad + if conv_kind == "dgrad" and split_k_slices > 1 and (ps.stride_h > 1 or ps.stride_w > 1): + return False + return True + + +class Conv2dLauncherFrontend: + def __init__(self, plan: cutlass_cppgen.Conv2d, seed: int = 80, backend="numpy"): + self.operation = plan + self.conv_kind = plan.conv_kind + self.seed = seed + self.backend = backend + + self.dtype_A = plan._element_a + self.dtype_B = plan._element_b + self.dtype_C = plan._element_c + self.dtype_acc = plan._element_accumulator + self.layout_A = LayoutType.TensorNHWC + self.layout_B = LayoutType.TensorNHWC + self.layout_C = LayoutType.TensorNHWC + self.layout_D = LayoutType.TensorNHWC 
+ + self.element_compute = DataType.f32 + + if self.dtype_A in [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.bf16]: + self.rand_max = 1 + else: + self.rand_max = 4 + self.activation = plan.activation + + def uniform_init(self, size, dtype): + tensor = torch.ceil( + torch.empty(size=size, dtype=torch_type(dtype), device="cuda").uniform_(-self.rand_max - 0.5, self.rand_max - 0.5) + ).to(memory_format=torch.channels_last) + return tensor + + def reference(self, ps, A, B, C, alpha, beta, activation): + if self.conv_kind == ConvKind.Fprop: + torch_result = alpha * torch.ops.aten.conv2d( + A, + B, + stride=(ps.stride_h, ps.stride_w), + padding=(ps.pad_h, ps.pad_w), + dilation=(ps.dilation_h, ps.dilation_w) + ) + beta * C + elif self.conv_kind == ConvKind.Dgrad: + torch_result = alpha * torch.nn.grad.conv2d_input( + (ps.N, ps.C, ps.H, ps.W), + B, + A, + padding=(ps.pad_h, ps.pad_w), + stride=(ps.stride_h, ps.stride_w) + ) + beta * C + elif self.conv_kind == ConvKind.Wgrad: + torch_result = alpha * torch.nn.grad.conv2d_weight( + B, + (ps.K, ps.C, ps.R, ps.S), + A, + padding=(ps.pad_h, ps.pad_w), + stride=(ps.stride_h, ps.stride_w) + ) + beta * C + else: + raise Exception(f"Conv kind {self.conv_kind} is currently unsupported.") + + if activation == cutlass_cppgen.backend.epilogue.relu: + torch_result = torch.nn.functional.relu(torch_result) + elif activation == cutlass_cppgen.backend.epilogue.leaky_relu: + torch_result = torch.nn.functional.leaky_relu(torch_result, 0.5) + return torch_result + + def run(self, ps, split_k_mode=SplitKMode.Serial, split_k_slices=1, alpha=1.0, beta=0.0): + if self.conv_kind == ConvKind.Fprop: + tensor_A_size = (ps.N, ps.C, ps.H, ps.W) + tensor_B_size = (ps.K, ps.C, ps.R, ps.S) + tensor_C_size = (ps.N, ps.K, ps.P, ps.Q) + elif self.conv_kind == ConvKind.Dgrad: + tensor_A_size = (ps.N, ps.K, ps.P, ps.Q) + tensor_B_size = (ps.K, ps.C, ps.R, ps.S) + tensor_C_size = (ps.N, ps.C, ps.H, ps.W) + elif self.conv_kind == ConvKind.Wgrad: + 
def add_test(
    cls,
    cc,
    conv_kind,
    problem_sizes,
    element,
    element_accumulator,
    element_output,
    opclass,
    threadblock_shape,
    warp_count,
    instruction_shape,
    stages,
    iterator_algorithm=None,
    swizzle=None,
    split_k_mode="serial",
    split_k_slices=1,
    activation="identity"
):
    """
    Build a test method for the given Conv2d specification and attach it to ``cls``

    The method name comes from ``get_name_conv2d``. When run, the method
    configures a ``cutlass_cppgen.Conv2d`` plan and checks every problem size
    accepted by ``validate_problem_size`` with alpha=1.0 and beta=2.0.

    :return: the generated test function (also set as an attribute of ``cls``)
    """
    test_name = get_name_conv2d(
        cc, conv_kind, element, element_accumulator,
        element_output, opclass, threadblock_shape, warp_count, instruction_shape, stages,
        iterator_algorithm, swizzle, split_k_mode, split_k_slices, activation)

    def run(self):
        # Describe the kernel: data types first, then opclass and tiling
        plan = cutlass_cppgen.Conv2d(
            kind=conv_kind,
            element=element,
            element_accumulator=element_accumulator,
            element_C=element_output,
            element_D=element_output
        )
        plan.opclass = opclass
        plan.tile_description = {
            "threadblock_shape": threadblock_shape,
            "warp_count": warp_count,
            "stages": stages,
            "instruction_shape": instruction_shape,
        }

        # Optional knobs are only touched when the caller supplied them
        if iterator_algorithm is not None:
            plan.iterator_algorithm = iterator_algorithm
        if swizzle is not None:
            plan.swizzling_stride = swizzle

        if activation == "leaky_relu":
            plan.activation = (cutlass_cppgen.epilogue.leaky_relu, 0.5)
        elif activation != "identity":
            plan.activation = getattr(cutlass_cppgen.epilogue, activation)

        conv2d_launcher = Conv2dLauncherFrontend(plan, 80, backend="torch")

        for ps in problem_sizes:
            if validate_problem_size(ps, conv_kind, split_k_slices):
                self.assertTrue(conv2d_launcher.run(ps, split_k_mode, split_k_slices, 1.0, 2.0))

    setattr(cls, test_name, run)

    return run
All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
if __name__ == '__main__':
    # Discover every conv2d_*.py module located next to this script and run
    # them in one suite; any failure is surfaced as an exception.
    here = str(pathlib.Path(__file__).parent.resolve()) + '/'
    suite = unittest.TestLoader().discover(here, 'conv2d_*.py')
    outcome = unittest.runner.TextTestRunner().run(suite)
    if not outcome.wasSuccessful():
        raise Exception('Test cases failed')
def _initialize(dtype, M: int, N: int, K: int):
    """
    Utility function to create the A, B, C, and D operands of an M x N x K GEMM

    Each tensor is drawn on the CUDA device with ``torch.randint(-3, 3, ...)``
    and then converted to ``dtype``.

    :param dtype: data type of tensors
    :param M: M dimension of GEMM problem
    :type M: int
    :param N: N dimension of GEMM problem
    :type N: int
    :param K: K dimension of GEMM problem
    :type K: int

    :return: initialized tensors A (M x K), B (K x N), C (M x N), and D (M x N)
    :rtype: list
    """
    shapes = ((M, K), (K, N), (M, N), (M, N))
    tensors = []
    for shape in shapes:
        tensors.append(torch.randint(-3, 3, shape, device='cuda').to(dtype))
    return tensors
def _generate_conv2d_problem(conv_kind, dtype, ps):
    """
    Utility function to generate conv2d inputs

    :param conv_kind: kind of convolution, one of "fprop", "dgrad", or "wgrad"
    :type conv_kind: str
    :param dtype: data type of tensors
    :param ps: the conv2d problem size
    :type ps: cutlass_cppgen.shape.Conv2DProblemSize

    :return: initialized tensors A, B, and C
    :rtype: list

    :raises ValueError: if ``conv_kind`` is not one of the supported kinds
    """
    if conv_kind == "fprop":
        tensor_A_size = (ps.N, ps.C, ps.H, ps.W)
        tensor_B_size = (ps.K, ps.C, ps.R, ps.S)
        tensor_C_size = (ps.N, ps.K, ps.P, ps.Q)
    elif conv_kind == "dgrad":
        tensor_A_size = (ps.N, ps.K, ps.P, ps.Q)
        tensor_B_size = (ps.K, ps.C, ps.R, ps.S)
        tensor_C_size = (ps.N, ps.C, ps.H, ps.W)
    elif conv_kind == "wgrad":
        tensor_A_size = (ps.N, ps.K, ps.P, ps.Q)
        tensor_B_size = (ps.N, ps.C, ps.H, ps.W)
        tensor_C_size = (ps.K, ps.C, ps.R, ps.S)
    else:
        # Reject misspelled kinds up front instead of silently producing
        # wgrad-shaped tensors for them.
        raise ValueError(
            f"Unsupported conv kind {conv_kind!r}; expected 'fprop', 'dgrad', or 'wgrad'"
        )

    sizes = [tensor_A_size, tensor_B_size, tensor_C_size]
    # ceil() of a uniform draw yields integer-valued entries
    return [
        torch.ceil(
            torch.empty(size, dtype=dtype, device='cuda').uniform_(-4.5, 3.5)
        ).to(memory_format=torch.channels_last)
        for size in sizes
    ]
+ D = mod.run(A, B, C, 1.0) + assert torch.allclose(D, D_ref) + + D = mod.run(A, B, C, 1.0, 0.0) + assert torch.allclose(D, D_ref) + + alpha = 2.0 + beta = -1.0 + D_ref = (A @ B) * alpha + (beta * C) + D = mod.run(A, B, C, alpha, beta) + assert torch.allclose(D, D_ref) + + def test_grouped_gemm(self): + random.seed(2023) + + dtype = torch.float16 + plan = cutlass_cppgen.op.GroupedGemm(element=dtype, layout=cutlass_cppgen.LayoutType.RowMajor) + op = plan.construct() + + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass_cppgen.emit.pytorch(op, name='grouped_gemm_mod', cc=plan.cc, sourcedir=tmpdir, jit=True) + + As, Bs, Cs, _ = _generate_problems(dtype, 50) + + def check_all(X, Y): + for x, y in zip(X, Y): + assert torch.allclose(x, y) + + Ds_ref = [a @ b for a, b in zip(As, Bs)] + Ds = mod.run(As, Bs) + check_all(Ds, Ds_ref) + + Ds = mod.run(As, Bs, Cs) + check_all(Ds, Ds_ref) + + Ds = mod.run(As, Bs, Cs, 1.0) + check_all(Ds, Ds_ref) + + Ds = mod.run(As, Bs, Cs, 1.0, 0.0) + check_all(Ds, Ds_ref) + + alpha = 2.0 + beta = -1.0 + Ds_ref = [(a @ b) * alpha + (beta * c) for a, b, c in zip(As, Bs, Cs)] + Ds = mod.run(As, Bs, Cs, alpha, beta) + check_all(Ds, Ds_ref) + + def test_conv2d_fprop(self): + torch.manual_seed(2023) + + dtype = torch.float16 + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=dtype, element_accumulator=torch.float32) + plan.activation = "relu" + + op = plan.construct() + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) + + problem_size = cutlass_cppgen.shape.Conv2DProblemSize( + 1, 4, 4, 16, + 8, 3, 3, 16, + 0, 0, + 3, 3, + 1, 1 + ) + + A, B, C = _generate_conv2d_problem("fprop", dtype, problem_size) + stride = (problem_size.stride_h, problem_size.stride_w) + padding = (problem_size.pad_h, problem_size.pad_w) + + alpha = 1.0 + beta = 0.5 + + D_ref = alpha * torch.ops.aten.conv2d( + A, B, stride=stride, padding=padding + ) + beta * C + 
D_ref = torch.nn.functional.relu(D_ref) + D = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta) + + assert torch.allclose(D, D_ref) + + # Test serial split-K + D_serial_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3) + assert torch.allclose(D, D_serial_split_k) + + # Test parallel split-K + D_parallel_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7) + assert torch.allclose(D, D_parallel_split_k) + + + def test_conv2d_dgrad(self): + torch.manual_seed(2023) + dtype = torch.float16 + plan = cutlass_cppgen.op.Conv2d(kind="dgrad", element=dtype, element_accumulator=torch.float32) + + op = plan.construct() + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_dgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) + + problem_size = cutlass_cppgen.shape.Conv2DProblemSize( + 1, 4, 4, 16, + 8, 3, 3, 16, + 0, 0, + 3, 3, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ) + + A, B, C = _generate_conv2d_problem("dgrad", dtype, problem_size) + stride = (problem_size.stride_h, problem_size.stride_w) + padding = (problem_size.pad_h, problem_size.pad_w) + + alpha = 1.0 + beta = 0.5 + input_size = (problem_size.N, problem_size.C, problem_size.H, problem_size.W) + D_ref = alpha * torch.nn.grad.conv2d_input( + input_size, B, A, + stride=stride, padding=padding + ) + beta * C + D = mod.run(input_size, A, B, C, stride, padding, alpha=alpha, beta=beta, ) + + assert torch.allclose(D, D_ref) + + def test_conv2d_wgrad(self): + torch.manual_seed(2023) + dtype = torch.float16 + plan = cutlass_cppgen.op.Conv2d(kind="wgrad", element=dtype, element_accumulator=torch.float32) + + op = plan.construct() + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_wgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) + + problem_size = cutlass_cppgen.shape.Conv2DProblemSize( + 1, 4, 4, 16, + 
8, 3, 3, 16, + 0, 0, + 3, 3, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ) + + A, B, C = _generate_conv2d_problem("wgrad", dtype, problem_size) + stride = (problem_size.stride_h, problem_size.stride_w) + padding = (problem_size.pad_h, problem_size.pad_w) + + alpha = 1.0 + beta = 0.5 + weight_size = (problem_size.K, problem_size.C, problem_size.R, problem_size.S) + D_ref = alpha * torch.nn.grad.conv2d_weight( + B, weight_size, A, + stride=stride, padding=padding + ) + beta * C + D = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta) + + assert torch.allclose(D, D_ref) + + # Test serial split-K + D_serial_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3) + assert torch.allclose(D, D_serial_split_k) + + # Test parallel split-K + D_parallel_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7) + assert torch.allclose(D, D_parallel_split_k) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_compute_sm80_90.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_compute_sm80_90.py new file mode 100644 index 0000000000000000000000000000000000000000..5467469e74e05573fb297b009914e0980e5ab222 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_compute_sm80_90.py @@ -0,0 +1,198 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
class TestEVTCompute(EVTTestCaseBase):
    """
    Tests for compute expressions in an epilogue function.

    Every test defines a small epilogue function, builds example inputs for
    each (m, n, k, l) returned by ``self.get_problem_sizes(8)``, and verifies
    the result through ``EVTTestBed.verify``.
    """

    def test_arith(self):
        """
        Test arithmetic ops: D = ((accum + C) * alpha - gamma) / beta
        """
        def evt_arith_compute(accum, C, alpha, beta, gamma):
            D = ((accum + C) * alpha - gamma) / beta
            return D

        for m, n, k, l in self.get_problem_sizes(8):
            example_inputs = {
                "accum": self.fake_tensor(self.element, (l, m, n)),
                "C": self.fake_tensor(self.element, (l, m, n)),
                "alpha": 1.5,
                "beta": 0.5,
                "gamma": 2.5,
                "D": self.fake_tensor(self.element, (l, m, n))
            }

            launcher = EVTTestBed(self.element, evt_arith_compute, example_inputs)
            input_keys = ["C", "alpha", "beta", "gamma"]
            result_keys = ["D"]
            launcher.verify((m, n, k), input_keys, result_keys, l)

    def test_func_call(self):
        """
        Test nested function calls: D = multiply_add(relu(accum + alpha) + C, beta, gamma)
        """
        def evt_func_call(accum, C, alpha, beta, gamma):
            D = multiply_add(relu(accum + alpha) + C, beta, gamma)
            return D

        for m, n, k, l in self.get_problem_sizes(8):
            example_inputs = {
                "accum": self.fake_tensor(self.element, (l, m, n)),
                "C": self.fake_tensor(self.element, (l, m, n)),
                "alpha": 1.5,
                "beta": 0.5,
                "gamma": 2.5,
                "D": self.fake_tensor(self.element, (l, m, n))
            }

            launcher = EVTTestBed(self.element, evt_func_call, example_inputs)
            input_keys = ["C", "alpha", "beta", "gamma"]
            result_keys = ["D"]
            launcher.verify((m, n, k), input_keys, result_keys, l)

    def test_func_call2(self):
        """
        Test a function call with a constant operand: D = maximum(alpha * accum + beta * C, 0.0)
        """

        def evt_func_call2(accum, C, alpha, beta):
            D = maximum(alpha * accum + beta * C, 0.0)
            return D

        for m, n, k, l in self.get_problem_sizes(8):
            example_inputs = {
                "accum": self.fake_tensor(self.element, (l, m, n)),
                "C": self.fake_tensor(self.element, (l, m, n)),
                "alpha": 1.5,
                "beta": 0.5,
                "D": self.fake_tensor(self.element, (l, m, n))
            }

            launcher = EVTTestBed(self.element, evt_func_call2, example_inputs)
            input_keys = ["C", "alpha", "beta"]
            result_keys = ["D"]
            launcher.verify((m, n, k), input_keys, result_keys, l)

    def test_tanh(self):
        """
        Test Tanh op
        """
        def evt_tanh(accum):
            D = tanh(accum)
            return D

        for m, n, k, l in self.get_problem_sizes(8):
            example_inputs = {
                "accum": self.fake_tensor(self.element, (l, m, n)),
                "D": self.fake_tensor(self.element, (l, m, n))
            }

            launcher = EVTTestBed(self.element, evt_tanh, example_inputs)
            # The accumulator is the only operand, so no extra inputs are listed
            input_keys = []
            result_keys = ["D"]
            launcher.verify((m, n, k), input_keys, result_keys, l)

    def test_sigmoid(self):
        """
        Test Sigmoid op
        """
        def evt_sigmoid(accum):
            D = sigmoid(accum)
            return D

        for m, n, k, l in self.get_problem_sizes(8):
            example_inputs = {
                "accum": self.fake_tensor(self.element, (l, m, n)),
                "D": self.fake_tensor(self.element, (l, m, n))
            }

            launcher = EVTTestBed(self.element, evt_sigmoid, example_inputs)
            # The accumulator is the only operand, so no extra inputs are listed
            input_keys = []
            result_keys = ["D"]
            launcher.verify((m, n, k), input_keys, result_keys, l)

    def test_gelu(self):
        """
        Test GELU op
        """
        def evt_gelu(accum):
            D = gelu(accum)
            return D

        for m, n, k, l in self.get_problem_sizes(8):
            example_inputs = {
                "accum": self.fake_tensor(self.element, (l, m, n)),
                "D": self.fake_tensor(self.element, (l, m, n))
            }

            launcher = EVTTestBed(self.element, evt_gelu, example_inputs)
            # The accumulator is the only operand, so no extra inputs are listed
            input_keys = []
            result_keys = ["D"]
            launcher.verify((m, n, k), input_keys, result_keys, l)

    def test_exp(self):
        """
        Test Exp op
        """
        def evt_exp(accum):
            D = exp(accum)
            return D

        for m, n, k, l in self.get_problem_sizes(8):
            example_inputs = {
                "accum": self.fake_tensor(self.element, (l, m, n)),
                "D": self.fake_tensor(self.element, (l, m, n))
            }

            launcher = EVTTestBed(self.element, evt_exp, example_inputs)
            # The accumulator is the only operand, so no extra inputs are listed
            input_keys = []
            result_keys = ["D"]
            launcher.verify((m, n, k), input_keys, result_keys, l)
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
class TestEVTLayout(EVTTestCaseBase):
    """
    Tests for layout-changing ops (``permute`` and ``reshape``) in an epilogue function.
    """

    def test_permute_1(self):
        """
        Permute F and C with indices (0, 2, 1), add them, and permute the sum
        back, so that D keeps the shape [l, m, n]
        """
        def evt_permute(accum, alpha, C):
            F = alpha * accum
            F_permute = permute(F, indices=(0, 2, 1))
            D_permute = F_permute + permute(C, indices=(0, 2, 1))
            D = permute(D_permute, indices=(0, 2, 1))
            return D, F

        for m, n, k, l in self.get_problem_sizes(8):
            example_inputs = {
                "accum": self.fake_tensor(self.element, (l, m, n)),
                "alpha": 0.5,
                "C": self.fake_tensor(self.element, (l, m, n)),
                "F": self.fake_tensor(self.element, (l, m, n)),
                "D": self.fake_tensor(self.element, (l, m, n)),
            }

            launcher = EVTTestBed(self.element, evt_permute, example_inputs)
            input_keys = ["C", "alpha"]
            result_keys = ["D", "F"]
            launcher.verify((m, n, k), input_keys, result_keys, l)

    @unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only")
    def test_permute_2(self):
        """
        Permute F with indices (0, 2, 1) and add C; C and D have shape [l, n, m]
        """
        def evt_permute(accum, alpha, C):
            F = alpha * accum
            F_permute = permute(F, indices=(0, 2, 1))
            D = F_permute + C
            return D, F

        for m, n, k, l in self.get_problem_sizes(8):
            example_inputs = {
                "accum": self.fake_tensor(self.element, (l, m, n)),
                "alpha": 0.5,
                "C": self.fake_tensor(self.element, (l, n, m)),
                "F": self.fake_tensor(self.element, (l, m, n)),
                "D": self.fake_tensor(self.element, (l, n, m)),
            }

            launcher = EVTTestBed(self.element, evt_permute, example_inputs)
            input_keys = ["C", "alpha"]
            result_keys = ["D", "F"]
            launcher.verify((m, n, k), input_keys, result_keys, l)

    @unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only")
    def test_permute_3(self):
        """
        Permute F with indices (1, 0, 2) and add C; C and D have shape [m, l, n]
        """
        def evt_permute(accum, alpha, C):
            F = alpha * accum
            F_permute = permute(F, indices=(1, 0, 2))
            D = F_permute + C
            return D, F

        for m, n, k, l in self.get_problem_sizes(8):
            example_inputs = {
                "accum": self.fake_tensor(self.element, (l, m, n)),
                "alpha": 0.5,
                "C": self.fake_tensor(self.element, (m, l, n)),
                "F": self.fake_tensor(self.element, (l, m, n)),
                "D": self.fake_tensor(self.element, (m, l, n)),
            }

            launcher = EVTTestBed(self.element, evt_permute, example_inputs)
            input_keys = ["C", "alpha"]
            result_keys = ["D", "F"]
            launcher.verify((m, n, k), input_keys, result_keys, l)

    def test_reshape(self):
        """
        Test reshape of an input: TensorE of shape (16, 32) is reshaped to
        (512, 1) before being added to F
        """
        def evt_reshape(accum, alpha, TensorE):
            F = alpha * accum
            E_reshape = reshape(TensorE, new_shape=(512, 1))
            D = F + E_reshape
            return D

        # NOTE(review): the hard-coded 512 presumably equals self.m - confirm against EVTTestCaseBase
        example_inputs = {
            "accum": self.fake_tensor(self.element, (self.l, self.m, self.n)),
            "alpha": 0.5,
            "TensorE": self.fake_tensor(self.element, (16, 32)),
            "D": self.fake_tensor(self.element, (self.l, self.m, self.n)),
        }

        launcher = EVTTestBed(self.element, evt_reshape, example_inputs)
        input_keys = ["alpha", "TensorE"]
        result_keys = ["D"]
        launcher.verify(self.problem_size, input_keys, result_keys, self.l)

    def test_reshape2(self):
        """
        Test reshape of the result: F is reshaped to (2, 3, 512, 256) before
        TensorE of shape (2, 3, 1, n) is added
        """
        def evt_reshape(accum, alpha, TensorE):
            F = alpha * accum
            F_reshape = reshape(F, new_shape=(2, 3, 512, 256))
            D = F_reshape + TensorE
            return D

        # NOTE(review): the hard-coded (2, 3, 512, 256) presumably matches
        # l = 6, m = 512, n = 256 in EVTTestCaseBase - confirm
        example_inputs = {
            "accum": self.fake_tensor(self.element, (self.l, self.m, self.n)),
            "alpha": 0.5,
            "TensorE": self.fake_tensor(self.element, (2, 3, 1, self.n)),
            "D": self.fake_tensor(self.element, (2, 3, self.m, self.n)),
        }

        launcher = EVTTestBed(self.element, evt_reshape, example_inputs)
        input_keys = ["alpha", "TensorE"]
        result_keys = ["D"]
        launcher.verify(self.problem_size, input_keys, result_keys, self.l)
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +""" +Unit test for load nodes in SM90 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * + +from utils.evt_testbed import EVTTestBed, EVTTestCaseBase + +cutlass_cppgen.set_log_level(logging.WARNING) + + +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTLoad(EVTTestCaseBase): + + def test_tensor_load(self): + """ + Load extra tensor with shape [m, n] + """ + def evt_tensor_load(accum, C, aux, aux_batch): + D = accum + C + aux + aux_batch + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "aux": self.fake_tensor(self.element, (m, n)), + "aux_batch": self.fake_tensor(np.float32, (l, m, n)), + "D": 
self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_tensor_load, example_inputs) + input_keys = ["C", "aux", "aux_batch"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_row_broadcast(self): + """ + Load extra tensor with shape [1, n] + """ + def evt_row_broadcast(accum, C, bias, bias_batch): + D = accum + C + bias + bias_batch + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "bias": self.fake_tensor(self.element, (n,)), + "bias_batch": self.fake_tensor(np.float32, (l, 1, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_row_broadcast, example_inputs) + input_keys = ["C", "bias", "bias_batch"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_column_broadcast(self): + """ + Load extra tensor with shape [m, 1] + """ + def evt_column_broadcast(accum, C, bias, bias_batch): + D = accum + C + bias + bias_batch + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "bias": self.fake_tensor(self.element, (m, 1)), + "bias_batch": self.fake_tensor(np.float32, (l, m, 1)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_column_broadcast, example_inputs) + input_keys = ["C", "bias", "bias_batch"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_scalar_broadcast(self): + """ + Load extra tensor with shape [1, 1] + """ + def evt_scalar_broadcast(accum, C, alpha, alpha_batch): + D = accum + C + alpha + alpha_batch + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + 
"C": self.fake_tensor(self.element, (l, m, n)), + "alpha": 0.5, + "alpha_batch": self.fake_tensor(np.float32, (l, 1, 1)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_scalar_broadcast, example_inputs) + input_keys = ["C", "alpha", "alpha_batch"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py new file mode 100644 index 0000000000000000000000000000000000000000..30dc8fe0d5ec413f1da57a8fa0875ed5e7baa887 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py @@ -0,0 +1,319 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +""" +Unittest for mixed types of nodes in SM90 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * +from cutlass_cppgen.swizzle import ThreadblockSwizzleStreamK + +from utils.evt_testbed import EVTTestBed, EVTTestCaseBase + +cutlass_cppgen.set_log_level(logging.WARNING) + + +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTMixed(EVTTestCaseBase): + + def test_same_variable_used_multiple_times(self): + """ + The same variable z0 is used multiple times + """ + def evt_aux_store(accum): + z0 = relu(accum) + D = z0 + z0 + return z0, D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + "z0": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_aux_store, example_inputs) + input_keys = ["accum"] + result_keys = ["z0", "D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def 
test_no_lca(self): + """ + The same variable z0 is used multiple times + """ + def evt_no_lca(accum, bias): + E = relu(accum) + F = E + bias + tmp_2 = E + 2 + D = tmp_2 + E + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + "bias": self.fake_tensor(self.element, (m,1), stride=(1,0)), + } + + launcher = EVTTestBed(self.element, evt_no_lca, example_inputs) + input_keys = ["accum", "bias"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_mixed_dag(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + if device_cc() == 80: + alignments = [2, 4, 8] + else: + # Sm90 EVT currently only supports 128-bit alignment + alignments = [8,] + for align in alignments: + for m, n, k, l in self.get_problem_sizes(align): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (l, m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") + def test_mixed_dag_float(self): + def 
evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + for align in [3, 2, 4]: + for m, n, k, l in self.get_problem_sizes(align): + example_inputs = { + "accum": self.fake_tensor(np.float32, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(np.float32, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(np.float32, (l, m, n)), + "cbias": self.fake_tensor(np.float32, (m, 1)), + "rbias": self.fake_tensor(np.float32, (n,)), + "D": self.fake_tensor(np.float32, (l, m, n)), + "F": self.fake_tensor(np.float32, (l, m, n)), + "F_row_max": self.fake_tensor(np.float32, (n,)), + "E_col_max": self.fake_tensor(np.float32, (m, 1)) + } + launcher = EVTTestBed(DataType.f32, evt_mixed_dag, example_inputs) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") + def test_mixed_dag_stage2(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (l, m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, 
(m, 1)) + } + + launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs, epilogue_stages=2) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") + def test_mixed_dag_partition_k(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (l, m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + tile_description = { + "threadblock_shape": [128, 128, 64], + "warp_count": [2, 2, 2] + } + + launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs, tile_description=tile_description, epilogue_stages=2) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") + def test_mixed_dag_stream_k(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + 
# High per-sm occupancy tile_description + tile_description = { + "threadblock_shape": [128, 128, 32], + "warp_count": [2, 2, 1], + "stages": 3 + } + tds = [None, tile_description] + for td in tds: + for m, n, k, l in self.get_problem_sizes(8, k=960, batch_count=[1, 3]): + if l == 1: + example_inputs = { + "accum": self.fake_tensor(self.element, (m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (m, n)), + "F": self.fake_tensor(self.element, (m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + else: + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (l, m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + if td is not None: + launcher = EVTTestBed( + self.element, evt_mixed_dag, example_inputs, + tile_description=td, + swizzling_functor=ThreadblockSwizzleStreamK, backend="torch") + else: + launcher = EVTTestBed( + self.element, evt_mixed_dag, example_inputs, + swizzling_functor=ThreadblockSwizzleStreamK, backend="torch") + + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_mixed_dag_no_batch(self): + def evt_mixed_dag_no_batch(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) 
+ E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + for m, n, k, _ in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (m, n)), + "F": self.fake_tensor(self.element, (m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + launcher = EVTTestBed(self.element, evt_mixed_dag_no_batch, example_inputs) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, 1) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py new file mode 100644 index 0000000000000000000000000000000000000000..b47f11e4f3bde3499948ae68b1b5bb79347f0fd1 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py @@ -0,0 +1,180 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################ + +""" +Unit test for store nodes in SM90 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * + +from utils.evt_testbed import EVTTestBed, EVTTestCaseBase + +cutlass_cppgen.set_log_level(logging.WARNING) + + +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTStore(EVTTestCaseBase): + + @unittest.skipIf(device_cc() != 90, "This test is only for CC 90") + def test_invalid_store(self): + """ + Test invalid store + """ + def evt_invalid_store(accum): + D = accum + F = D + 1 # D has users, which is not allowed on SM90 or higher + return D, F + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)) + } + with self.assertRaisesRegex( + RuntimeError, + r"On SM90 or higher, D is expected to be a output node with 0 users " + r"to enable smem reuse between C and D, but got 1" + ): + launcher = EVTTestBed(self.element, evt_invalid_store, example_inputs) + + break # Only need to test once + + def test_aux_store(self): + """ + Returning a tensor with shape [m, n] + """ + def evt_aux_store(accum, alpha, C): + F = alpha * accum + D = F + C + return D, F + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 0.5, + "C": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_aux_store, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_col_reduce(self): + """ + Reduction 
[m, n] -> [m, 1] + """ + def evt_row_reduce(accum, alpha, C): + acc_row_max = max(accum, dim=[2,]) + F = alpha * accum + F_row_max = max(F, dim=[0, 2]) + D = F + C + return D, F_row_max, acc_row_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 2.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(np.float32, (m, 1)), + "acc_row_max": self.fake_tensor(np.float32, (l, m, 1)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_row_reduce, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F_row_max", "acc_row_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_row_reduce(self): + """ + Reduction [m, n] -> [n] + """ + def evt_col_reduce(accum, alpha, C): + acc_col_max = max(accum, dim=[1,]) + F = alpha * accum + F_col_max = max(F, dim=[0, 1]) + D = F + C + return D, F_col_max, acc_col_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 2.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "F_col_max": self.fake_tensor(np.float32, (n,)), + "acc_col_max": self.fake_tensor(np.float32, (l, 1, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_col_reduce, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F_col_max", "acc_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_scalar_reduce(self): + """ + Reduction [m, n] -> [1,] + """ + def evt_scalar_reduce(accum, alpha, C): + acc_max = max(accum, dim=[1, 2]) + F = alpha * accum + F_max = max(F, dim=[0, 1, 2]) + D = F + C + return D, F_max, acc_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 2.0, + "C": self.fake_tensor(self.element, (l, m, 
n)), + "acc_max": self.fake_tensor(np.float32, (l, 1, 1)), + "F_max": self.fake_tensor(np.float32, (1,)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_scalar_reduce, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F_max", "acc_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..5bb84e2e8c85e602b45b9ee18ce324accd3a32cd --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py @@ -0,0 +1,44 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import pathlib +import unittest + + +if __name__ == '__main__': + loader = unittest.TestLoader() + script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/' + tests = loader.discover(script_dir, 'evt_*.py') + testRunner = unittest.runner.TextTestRunner() + results = testRunner.run(tests) + if not results.wasSuccessful(): + raise Exception('Test cases failed') diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py new file mode 100644 index 0000000000000000000000000000000000000000..62d375d856ffaef6be50b39b76121e0eb78a7465 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py @@ -0,0 +1,235 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
class EVTReferenceModule:
    """
    Host-side reference for a GEMM followed by an epilogue visitor.

    The GEMM part is evaluated with torch; the epilogue is whatever callable
    is passed as ``epilogue_visitor`` and is invoked with keyword arguments.
    """

    def __init__(self, layout_A, layout_B, layout_C, epilogue_visitor):
        """
        :param layout_A: layout tag of operand A (compared against
            ``cutlass_cppgen.LayoutType.RowMajor``)
        :param layout_B: layout tag of operand B
        :param layout_C: layout tag of operand C and of the output
        :param epilogue_visitor: callable invoked as ``epilogue_visitor(**args)``
        """
        self.layout_A = layout_A
        self.layout_B = layout_B
        self.layout_C = layout_C
        self.epilogue_visitor = epilogue_visitor

    @staticmethod
    def _as_row_major(tensor, layout, batch, rows, cols):
        """View ``tensor`` as a (batch, rows, cols) row-major operand."""
        if layout == cutlass_cppgen.LayoutType.RowMajor:
            return tensor.view((batch, rows, cols))
        # Any other layout is stored transposed: view it as
        # (batch, cols, rows) and swap the two trailing axes.
        return torch.permute(tensor.view((batch, cols, rows)), (0, 2, 1))

    def run(self, A, B, C, problem_size, alpha, beta, batch=1):
        """
        Compute ``alpha * (A @ B) + beta * C`` and return it flattened.

        :param A: tensor viewable as (batch, m, k) (or its transpose)
        :param B: tensor viewable as (batch, k, n) (or its transpose)
        :param C: tensor viewable as (batch, m, n) (or its transpose)
        :param problem_size: object exposing ``m``, ``n`` and ``k``
        :param alpha: scale applied to the matrix product
        :param beta: scale applied to C
        :param batch: leading batch extent used for the views
        :return: 1-D tensor holding the result in ``layout_C`` order
        """
        m, n, k = problem_size.m, problem_size.n, problem_size.k
        A_row = self._as_row_major(A, self.layout_A, batch, m, k)
        B_row = self._as_row_major(B, self.layout_B, batch, k, n)
        C_row = self._as_row_major(C, self.layout_C, batch, m, n)

        out_row = torch.matmul(A_row, B_row) * alpha + C_row * beta

        # Only a column-major C transposes the output back before flattening.
        if self.layout_C == cutlass_cppgen.LayoutType.ColumnMajor:
            out = torch.permute(out_row, (0, 2, 1))
        else:
            out = out_row

        return torch.flatten(out)

    def __call__(self, A, B, C, problem_size, batch=1, epilogue_args=None):
        """
        Run the mainloop (alpha=1, beta=0) and then the epilogue visitor.

        :param epilogue_args: keyword arguments forwarded to the epilogue
            visitor; ``None`` is treated as "no extra arguments". The mapping
            is copied, so the caller's dict is never modified.
        :return: tuple of reference results (a single result is wrapped)
        """
        # Running the mainloop
        accum = self.run(
            A, B, C, problem_size, 1.0, 0.0, batch=batch
        ).reshape(batch, problem_size.m, problem_size.n)

        # Running the epilogue on a private copy of the arguments; the
        # original code subscripted the ``None`` default and wrote "accum"
        # into the caller's dict.
        visitor_kwargs = dict(epilogue_args) if epilogue_args is not None else {}
        visitor_kwargs["accum"] = accum
        references = self.epilogue_visitor(**visitor_kwargs)

        # Return the results
        if not isinstance(references, tuple):
            references = (references,)
        return references


class EVTTestBed:
    """
    Epilogue Visitor Testbed
    """
    def __init__(self, element, evt_fn, example_inputs, profile=False, **kwargs) -> None:
        """
        :param element: data type of the GEMM operands
        :param evt_fn: epilogue function traced into an epilogue visitor
        :param example_inputs: mapping of argument name to example value
            used for tracing and for sizing the test tensors
        :param profile: when true, ``verify`` also profiles the kernel
        :param kwargs: optional ``tile_description``, ``swizzling_functor``
            and ``epilogue_stages`` overrides
        """
        self.element = element
        layout = cutlass_cppgen.LayoutType.RowMajor
        self.example_inputs = example_inputs

        # Create the Gemm plan
        self.plan = cutlass_cppgen.op.Gemm(element=element, layout=layout, element_accumulator=torch.float32)

        if "tile_description" in kwargs:
            self.plan.tile_description = kwargs["tile_description"]

        if "swizzling_functor" in kwargs:
            self.plan.swizzling_functor = kwargs["swizzling_functor"]

        # Compile the epilogue visitor
        epilogue_visitor = cutlass_cppgen.epilogue.trace(evt_fn, example_inputs)
        if "epilogue_stages" in kwargs:
            epilogue_visitor.epilogue_stages = kwargs["epilogue_stages"]
        self.plan.epilogue_visitor = epilogue_visitor

        # Reference model
        self.reference_fn = EVTReferenceModule(layout, layout, layout, epilogue_visitor)

        self.profile = profile

    def get_torch_tensor(self, shape, dtype=None, fill=None):
        """
        Create a CUDA tensor of ``shape``.

        With ``fill=None`` the tensor holds integer-valued random numbers
        (ceil of a uniform sample from [-4.5, 3.5)); otherwise every element
        equals ``fill``. ``dtype`` defaults to the testbed element type.
        """
        if dtype is None:
            dtype = self.element

        dtype = torch_type(dtype)
        if fill is None:
            return torch.ceil(
                torch.empty(size=shape, dtype=dtype, device="cuda").uniform_(-4.5, 3.5)
            )
        return torch.full(shape, fill, dtype=dtype, device="cuda")

    def verify(self, problem_size, input_keys, result_keys, batch_count=1):
        """
        Verify the results

        :param problem_size: (m, n, k) triple
        :param input_keys: names of epilogue inputs
        :param result_keys: names of epilogue outputs, in the order the
            reference returns them
        :param batch_count: batch extent of A and B
        :raises AssertionError: if a device result differs from the reference
        """
        problem_size = GemmCoord(*problem_size)

        # Initiate the GEMM arguments
        tensor_A = self.get_torch_tensor((batch_count, problem_size.m, problem_size.k))
        tensor_B = self.get_torch_tensor((batch_count, problem_size.k, problem_size.n))

        # Initialize the epilogue args
        epilogue_args = {}
        for key, example in self.example_inputs.items():
            if key in input_keys:
                fill = None  # inputs get random data
            elif key in result_keys:
                # Outputs are pre-filled; keys containing "max" start very low.
                fill = -1000 if "max" in key else 0
            else:
                continue
            if isinstance(example, Tensor):
                epilogue_args[key] = self.get_torch_tensor(example.shape, example.element, fill=fill)
            else:
                epilogue_args[key] = example

        tensor_D = epilogue_args["D"]
        # Without an explicit C operand, D doubles as C.
        tensor_C = epilogue_args["C"] if "C" in epilogue_args else tensor_D

        # Run the device kernel
        self.plan.run(tensor_A, tensor_B, tensor_C, tensor_D, visitor_args=epilogue_args)

        # Run the host reference
        evt_args_inputs = {key: epilogue_args[key] for key in input_keys}

        reference_results = self.reference_fn(
            tensor_A, tensor_B, tensor_C, problem_size, batch_count, evt_args_inputs)

        # Compare the results. An explicit raise is used instead of a bare
        # ``assert`` so the check still runs under ``python -O``.
        for result, ref in zip(result_keys, reference_results):
            expected = ref.masked_fill(torch.isnan(ref), float('inf')).flatten()
            if not torch.equal(epilogue_args[result].flatten(), expected):
                raise AssertionError(
                    f"EVT result '{result}' does not match the host reference")

        # Run profile
        if self.profile:
            profiler = CUDAEventProfiler(
                self.plan, 100, 100, tensor_A, tensor_B, tensor_C, tensor_D,
                visitor_args=epilogue_args
            )
            print(f"Cutlass Python Duration: {profiler()}")


class EVTTestCaseBase(unittest.TestCase):
    """
    Base class for EVT Unittest
    """
    def __init__(self, methodName: str = "runTest", lmnk=(6, 512, 256, 128)) -> None:
        """
        :param methodName: forwarded to ``unittest.TestCase``
        :param lmnk: default (l, m, n, k) extents stored on the instance
        """
        super().__init__(methodName)

        self.element = cutlass_cppgen.DataType.f16
        self.l, self.m, self.n, self.k = lmnk

        self.problem_size = (self.m, self.n, self.k)

        torch.random.manual_seed(42)

    def fake_tensor(self, element, shape, stride=None):
        """
        Build a ``Tensor`` description; row-major unless ``stride`` is given.
        """
        if stride is None:
            return Tensor(element=element, shape=shape, layout_tag=cutlass_cppgen.LayoutType.RowMajor)
        return Tensor(element=element, shape=shape, stride=stride)

    def get_problem_sizes(self, alignment, k=None, batch_count=(3,)):
        """
        Enumerate (m, n, k, l) problem sizes for the given alignment.

        :param alignment: operand alignment; also the smallest m and n used
        :param k: K extent; a falsy value falls back to ``self.k``
        :param batch_count: iterable of batch extents (immutable default
            instead of the previous shared list)
        :return: list of (m, n, k, l) tuples, m-major then n then l
        """
        k = k or self.k
        problem_size_m = [alignment, 512 - 3 * alignment]
        problem_size_n = [alignment, 512 - alignment]
        if alignment % 8 == 0:
            problem_size_m.append(768)
            problem_size_n.append(768)

        return [
            (m, n, k, l)
            for m in problem_size_m
            for n in problem_size_n
            for l in batch_count
        ]
def pytorch_reference(A, B, C, alpha, beta):
    """
    Compute ``alpha * (A @ B) + beta * C`` with torch as a batched GEMM.

    Any of A, B and C may carry leading batch dimensions; operands that are
    plain 2-D matrices are broadcast across the batch.

    :param A: tensor of shape (..., M, K)
    :param B: tensor of shape (..., K, N)
    :param C: tensor of shape (..., M, N)
    :param alpha: scale applied to the product
    :param beta: scale applied to C
    :return: tensor of shape ``batch + (M, N)``; when no operand is batched
        the batch shape is ``(1,)``
    """
    # Get the batch count. Assume that any of A, B, and C
    # with a batch dimension have matching batch count. Thus,
    # we break out of the loop once we have found the first
    # tensor containing a batch dimension.
    batch_count = (1,)
    for tensor in [A, B, C]:
        if len(tensor.shape) > 2:
            batch_count = tensor.shape[:-2]
            break

    int_batch_count = prod(batch_count)

    def add_batch(tensor):
        """Return ``tensor`` with exactly one leading batch dimension."""
        if len(tensor.shape) == 2:
            # A broadcast view is enough here; no need to materialise
            # int_batch_count copies of the matrix.
            return tensor.unsqueeze(0).expand(int_batch_count, -1, -1)
        return tensor.reshape(-1, tensor.size(-2), tensor.size(-1))

    # Reshape tensors to have batch dimension
    A = add_batch(A)
    B = add_batch(B)
    C = add_batch(C)

    ret = (torch.bmm(A, B) * alpha) + (C * beta)
    reshape_vals = tuple(batch_count) + tuple(C.shape[-2:])
    return ret.reshape(*reshape_vals)


def initialize(rows, cols, batch):
    """
    Create a half-precision CUDA tensor of random integers in [-3, 3).

    :param rows: number of rows of each matrix
    :param cols: number of columns of each matrix
    :param batch: sequence of batch extents
    :return: tensor of shape ``batch + (rows, cols)`` when the batch holds
        more than one matrix, otherwise a plain (rows, cols) matrix
    """
    batch_size = prod(batch)
    tensor = torch.randint(-3, 3, size=(rows * cols * batch_size,), device='cuda').half()
    if len(batch) > 0 and batch_size > 1:
        # tuple() lets callers pass a list as well as a tuple.
        reshape_vals = tuple(batch) + (rows, cols)
        return tensor.reshape(*reshape_vals)
    return tensor.reshape(rows, cols)
+ + A = initialize(M, K, batch_count if batch_A else (1,)) + B = initialize(K, N, batch_count if batch_B else (1,)) + C = initialize(M, N, batch_count if batch_C else (1,)) + D = initialize(M, N, batch_count) + + plan = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=cutlass_cppgen.DataType.f32) + plan.run(A, B, C, D, alpha, beta) + reference = pytorch_reference(A, B, C, alpha, beta) + assert reference.equal(D) + + def test_batched_ABC(self): + self.run_batched((3,), True, True, True) + self.run_batched((2, 3), True, True, True) + + def test_batched_AB(self): + self.run_batched((3,), True, True, False) + self.run_batched((2, 3), True, True, False) + + def test_batched_AC(self): + self.run_batched((3,), True, False, True) + self.run_batched((2, 3), True, False, True) + + def test_batched_BC(self): + self.run_batched((3,), False, True, True) + self.run_batched((2, 3), False, True, True) + + def test_batched_A(self): + self.run_batched((3,), True, False, False) + self.run_batched((2, 3), True, False, False) + + def test_batched_B(self): + self.run_batched((3,), False, True, False) + self.run_batched((2, 3), False, True, False) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..dbd26951ec5d8a1eb6cbe38491c64fde2873b9c3 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py @@ -0,0 +1,128 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
"""
Low-level functionality tests for GEMM with F16 operands on SM80
"""

from functools import partial
import logging
import unittest

import cutlass_cppgen
from cutlass_cppgen.backend.utils.device import device_cc

from utils import LayoutCombination, add_test_gemm


cutlass_cppgen.set_log_level(logging.WARNING)
cc = 80
dtype = cutlass_cppgen.DataType.f16


@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmF16Sm80(unittest.TestCase):
    """
    Placeholder test case; its tests are registered dynamically by the
    add_test_* calls below.
    """


@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmF16Sm80StreamK(unittest.TestCase):
    """
    Placeholder test case for the stream-K swizzle; its tests are registered
    dynamically by the add_test_streamk calls below.
    """


# Short aliases for the two element types used throughout the tables.
_F16 = cutlass_cppgen.DataType.f16
_F32 = cutlass_cppgen.DataType.f32

add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1])

# Tests using TensorOp
add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)

# Each row: (layouts, alignments, element_accumulator, threadblock_shape, warp_count, stages).
# Output and C element types are f16 for every row. Registration order is
# the same as the original one-call-per-line listing.
_TENSOROP_CASES = (
    (LayoutCombination.NNN, (8, 8, 8), _F32, (128, 128, 32), (2, 2, 1), 3),
    (LayoutCombination.NNT, (8, 8, 8), _F32, (128, 128, 32), (2, 2, 1), 3),
    (LayoutCombination.NTN, (8, 8, 8), _F32, (128, 128, 32), (2, 2, 1), 3),
    (LayoutCombination.NTT, (8, 8, 8), _F32, (128, 128, 32), (2, 2, 1), 3),
    (LayoutCombination.TNN, (8, 8, 8), _F32, (128, 128, 32), (2, 2, 1), 3),
    (LayoutCombination.TNT, (8, 8, 8), _F32, (128, 128, 32), (2, 2, 1), 3),
    (LayoutCombination.TTN, (8, 8, 8), _F32, (128, 128, 32), (2, 2, 1), 3),
    (LayoutCombination.TTT, (8, 8, 8), _F32, (128, 128, 32), (2, 2, 1), 3),
    (LayoutCombination.TNT, (8, 8, 8), _F32, (64, 128, 32), (1, 2, 1), 3),
    (LayoutCombination.TNT, (8, 8, 8), _F32, (128, 64, 32), (2, 1, 1), 3),
    (LayoutCombination.TNT, (8, 8, 8), _F32, (64, 64, 64), (1, 1, 1), 3),
    (LayoutCombination.TNT, (4, 4, 8), _F32, (128, 128, 32), (2, 2, 1), 3),
    (LayoutCombination.TNT, (4, 4, 8), _F16, (128, 128, 32), (2, 2, 1), 3),
    (LayoutCombination.TNT, (8, 8, 8), _F16, (128, 128, 32), (2, 2, 1), 3),
    (LayoutCombination.TNT, (8, 8, 8), _F32, (64, 64, 64), (1, 1, 1), 5),
    (LayoutCombination.TNT, (2, 2, 2), _F16, (128, 128, 32), (2, 2, 1), 3),
)
for _layouts, _alignments, _accumulator, _tb_shape, _warps, _stages in _TENSOROP_CASES:
    add_test_tensorop(cls=GemmF16Sm80, layouts=_layouts, alignments=list(_alignments),
                      element_output=_F16, element_C=_F16, element_accumulator=_accumulator,
                      threadblock_shape=list(_tb_shape), warp_count=list(_warps), stages=_stages)

# Tests using SIMT
add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt)

# Each row: (layouts, element_accumulator, threadblock_shape, warp_count).
# Alignment is [1, 1, 1] and stages is 2 for every SIMT test.
_SIMT_CASES = (
    (LayoutCombination.NNN, _F32, (128, 128, 8), (2, 2, 1)),
    (LayoutCombination.TNN, _F32, (64, 128, 8), (1, 2, 1)),
    (LayoutCombination.NTN, _F32, (128, 64, 8), (2, 1, 1)),
    (LayoutCombination.TTN, _F32, (64, 64, 8), (1, 1, 1)),
    (LayoutCombination.NNT, _F16, (128, 128, 8), (2, 2, 1)),
)
for _layouts, _accumulator, _tb_shape, _warps in _SIMT_CASES:
    add_test_simt(cls=GemmF16Sm80, layouts=_layouts, alignments=[1, 1, 1],
                  element_output=_F16, element_C=_F16, element_accumulator=_accumulator,
                  threadblock_shape=list(_tb_shape), warp_count=list(_warps), stages=2)

# Stream K tests
add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK)

# Each row: (layouts, threadblock_shape, warp_count, stages).
_STREAMK_CASES = (
    (LayoutCombination.NNN, (128, 128, 32), (2, 2, 1), 3),
    (LayoutCombination.TNT, (64, 64, 64), (1, 1, 1), 5),
)
for _layouts, _tb_shape, _warps, _stages in _STREAMK_CASES:
    add_test_streamk(cls=GemmF16Sm80StreamK, layouts=_layouts, alignments=[8, 8, 8],
                     element_output=_F16, element_C=_F16, element_accumulator=_F32,
                     threadblock_shape=list(_tb_shape), warp_count=list(_warps), stages=_stages)

if __name__ == '__main__':
    unittest.main()
"""
Low-level functionality tests for GEMM with F16 operands on SM90
"""

from functools import partial
import logging
import unittest

import cutlass_cppgen
from cutlass_cppgen.backend.utils.device import device_cc

from utils import LayoutCombination, add_test_gemm


cutlass_cppgen.set_log_level(logging.WARNING)
cc = 90
dtype = cutlass_cppgen.DataType.f16


@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmF16Sm90(unittest.TestCase):
    """
    Placeholder test case; its tests are registered dynamically by the
    add_test_* calls below.
    """


# Short aliases for the element types used throughout the tables.
_F16 = cutlass_cppgen.DataType.f16
_F32 = cutlass_cppgen.DataType.f32

add_test_specialized = partial(add_test_gemm, cls=GemmF16Sm90, element=dtype,
                               warp_count=None, compilation_modes=['nvcc'])

add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)

# Tests with 1x1x1 clusters
add_test_unit_cluster = partial(add_test_tensorop, cluster_shape=[1, 1, 1])

# Each row: (layouts, alignments, element_accumulator, threadblock_shape, stages).
# The output element type is f16 for every row.
_UNIT_CLUSTER_CASES = (
    (LayoutCombination.NNN, (8, 8, 8), _F32, (128, 128, 32), 3),
    (LayoutCombination.NNT, (8, 8, 8), _F32, (128, 128, 32), None),
    (LayoutCombination.NTN, (8, 8, 8), _F32, (128, 128, 32), None),
    (LayoutCombination.NTT, (8, 8, 8), _F32, (128, 128, 32), None),
    (LayoutCombination.TNN, (8, 8, 8), _F32, (128, 128, 32), None),
    (LayoutCombination.TNT, (4, 4, 8), _F32, (128, 128, 32), None),
    (LayoutCombination.TNT, (4, 4, 8), _F16, (128, 128, 32), None),
    (LayoutCombination.TNT, (8, 8, 8), _F16, (128, 128, 32), None),
    (LayoutCombination.TNT, (8, 8, 8), _F32, (64, 64, 64), 5),
    (LayoutCombination.TNT, (2, 2, 2), _F16, (128, 128, 32), None),
)
for _layouts, _alignments, _accumulator, _tb_shape, _stages in _UNIT_CLUSTER_CASES:
    add_test_unit_cluster(layouts=_layouts, alignments=list(_alignments), element_output=_F16,
                          element_accumulator=_accumulator, threadblock_shape=list(_tb_shape), stages=_stages)

# Tests with different cluster shapes
add_test_cluster_shape = partial(add_test_tensorop, threadblock_shape=[64, 128, 64], stages=None)

# Each row: (layouts, alignments, element_output, element_accumulator, cluster_shape).
_CLUSTER_SHAPE_CASES = (
    (LayoutCombination.TTN, (8, 8, 8), _F16, _F16, (2, 2, 1)),
    (LayoutCombination.TNN, (8, 8, 4), _F32, _F32, (2, 2, 1)),
    (LayoutCombination.NTN, (8, 8, 4), _F32, _F32, (2, 2, 1)),
    (LayoutCombination.NNN, (8, 8, 4), _F32, _F32, (2, 2, 1)),
    (LayoutCombination.TTN, (8, 8, 4), _F32, _F32, (1, 4, 1)),
    (LayoutCombination.TTN, (8, 8, 4), _F32, _F32, (2, 4, 1)),
    (LayoutCombination.TTN, (8, 8, 4), _F32, _F32, (4, 1, 1)),
    (LayoutCombination.TTN, (8, 8, 4), _F32, _F32, (4, 2, 1)),
)
for _layouts, _alignments, _output, _accumulator, _cluster in _CLUSTER_SHAPE_CASES:
    add_test_cluster_shape(layouts=_layouts, alignments=list(_alignments), element_output=_output,
                           element_accumulator=_accumulator, cluster_shape=list(_cluster))

# Tests for different schedule modes
add_test_schedule = partial(add_test_specialized, layouts=LayoutCombination.TTN, alignments=[8, 8, 4],
                            element_output=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32,
                            opclass=cutlass_cppgen.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], stages=None)

# Each row: (cluster_shape, kernel_schedule, epilogue_schedule).
_SCHEDULE_CASES = (
    ((1, 1, 1),
     cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong,
     cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized),
    ((1, 1, 1),
     cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative,
     cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative),
    ((2, 1, 1),
     cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong,
     cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized),
    ((2, 1, 1),
     cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative,
     cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative),
)
for _cluster, _kernel_schedule, _epilogue_schedule in _SCHEDULE_CASES:
    add_test_schedule(
        cluster_shape=list(_cluster),
        kernel_schedule=_kernel_schedule,
        epilogue_schedule=_epilogue_schedule
    )

# Tests using SIMT
add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt, alignments=[1, 1, 1], cluster_shape=[1, 1, 1], stages=2)

# Each row: (layouts, element_accumulator, threadblock_shape); output is f16.
_SIMT_CASES = (
    (LayoutCombination.NNN, _F32, (128, 128, 8)),
    (LayoutCombination.TNN, _F32, (64, 128, 8)),
    (LayoutCombination.NTN, _F32, (128, 64, 8)),
    (LayoutCombination.TTN, _F32, (64, 64, 8)),
    (LayoutCombination.NNT, _F16, (128, 128, 8)),
)
for _layouts, _accumulator, _tb_shape in _SIMT_CASES:
    add_test_simt(layouts=_layouts, element_output=_F16, element_accumulator=_accumulator,
                  threadblock_shape=list(_tb_shape))

# Tests with void-C kernels. The explicit threadblock_shape/stages keywords
# override the defaults baked into add_test_cluster_shape.
add_test_cluster_shape(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
                       element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None,
                       cluster_shape=[2, 1, 1], element_C=cutlass_cppgen.DataType.void)

if __name__ == '__main__':
    unittest.main()
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F32 operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype = cutlass_cppgen.DataType.f32 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF32Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF32Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be 
added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) + +# Tests using TensorOp +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4) +# Tests using SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) + +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 
1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) + +# Stream K tests +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..3075ddf74bf2a119759ca1a3e47c0815f4b0923c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py @@ -0,0 +1,103 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F64 operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype = cutlass_cppgen.DataType.f64 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF64Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF64Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) + +# Tests using TensorOp +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 16], warp_count=[2, 2, 1], stages=4) +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 
1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5) + +# Tests using SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) + +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) + +# Stream K tests +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py new file mode 100644 index 0000000000000000000000000000000000000000..9bf36fc77436fef22882e98c752b7a599cf7fb95 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py @@ -0,0 +1,71 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F64 operands on SM90 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 90 +dtype = cutlass_cppgen.DataType.f64 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF64Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmF64Sm90, alignments=[1, 1, 1], cluster_shape=[1, 1, 1], + element=dtype, element_output=dtype, element_accumulator=dtype, compilation_modes=['nvcc']) + +add_test_specialized(opclass=cutlass_cppgen.OpcodeClass.TensorOp, layouts=LayoutCombination.NNT, threadblock_shape=[128, 128, 32], stages=3) +add_test_specialized(opclass=cutlass_cppgen.OpcodeClass.TensorOp, layouts=LayoutCombination.TNN, threadblock_shape=[128, 128, 32], stages=3) +add_test_specialized( opclass=cutlass_cppgen.OpcodeClass.Simt, layouts=LayoutCombination.NNN, 
threadblock_shape=[128, 128, 8], stages=2) +add_test_specialized( opclass=cutlass_cppgen.OpcodeClass.Simt, layouts=LayoutCombination.TTT, threadblock_shape=[ 64, 128, 8], stages=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py new file mode 100644 index 0000000000000000000000000000000000000000..fef6d457a6528a61613d1295877a2b6b8f80fef5 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py @@ -0,0 +1,112 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F8 operands on SM90 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 90 +dtype = cutlass_cppgen.DataType.e4m3 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF8E4M3Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmF8E4M3Sm90, element=dtype, compilation_modes=['nvcc']) + +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +# Test with 1x1x1 clusters +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) + +# Tests with different cluster shapes +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], 
element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None) + +# Tests with warp-specialized ping-pong schedule +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None, + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized) + +# Tests for SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) +add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2) + + +# +# Add a test for E5M2 +# +dtype = cutlass_cppgen.DataType.e5m2 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF8E5M2Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmF8E5M2Sm90, element=dtype, compilation_modes=['nvcc']) + +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +# Tests with 1x1x1 clusters +add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=dtype, + 
element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..0a002a5fbad80de5f7b29e42db0806469244914c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py @@ -0,0 +1,75 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with mixed operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype =cutlass_cppgen.DataType.f16 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmMixedSm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_mixed = partial(add_test_gemm, cls=GemmMixedSm80, element=dtype, cc=cc, cluster_shape=[1, 1, 1], + opclass=cutlass_cppgen.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, element_accumulator=cutlass_cppgen.DataType.f32) + +# Test with upcast on A +add_test_mixed(element_A=cutlass_cppgen.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNT) +add_test_mixed(element_A=cutlass_cppgen.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNN) + +# Test with upcast on B +add_test_mixed(element_B=cutlass_cppgen.DataType.s8, 
alignments=[8, 16, 8], layouts=LayoutCombination.TNT) +add_test_mixed(element_B=cutlass_cppgen.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNN) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..e226e23684147cb0a9cd5c1270468eb96c67ba15 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py @@ -0,0 +1,103 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with S8 operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype = cutlass_cppgen.DataType.s8 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmS8Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmS8Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) + +# Tests using TensorOp +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) 
+ +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[256, 128, 64], warp_count=[4, 2, 1], stages=3) +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3) +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 4], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=4) + +# Tests using SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) + +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32, + 
element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) + +# Stream K tests +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmS8Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py new file mode 100644 index 0000000000000000000000000000000000000000..ec0101f78da3b62b599a5deeb89f5596a7e515ce --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py @@ -0,0 +1,98 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#
#################################################################################################

"""
Low-level functionality tests for GEMM with S8 operands on SM90
"""

from functools import partial
import logging
import unittest

import cutlass_cppgen
from cutlass_cppgen.backend.utils.device import device_cc

from utils import LayoutCombination, add_test_gemm


cutlass_cppgen.set_log_level(logging.WARNING)
cc = 90
dtype = cutlass_cppgen.DataType.s8


@unittest.skipIf(
    device_cc() < cc,
    'Device compute capability is insufficient for SM90 tests.',
)
@unittest.skipIf(
    cutlass_cppgen.utils.datatypes.torch_type(dtype) is None,
    f'Version of torch installed does not contain a datatype match for {dtype}',
)
class GemmS8Sm90(unittest.TestCase):
    """
    Empty test case; its test methods are attached by the module-level
    ``add_test_*`` calls that follow this class.
    """


# Every test in this file targets GemmS8Sm90, uses S8 A/B operands and is compiled with nvcc only.
add_test_specialized = partial(add_test_gemm, cls=GemmS8Sm90, element=dtype, compilation_modes=['nvcc'])

add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)

# Values shared by every case below.
_S8 = cutlass_cppgen.DataType.s8
_S32 = cutlass_cppgen.DataType.s32
_TNN = LayoutCombination.TNN
_TNT = LayoutCombination.TNT

# Tensor-op cases, registered in the order listed.
_TENSOROP_CASES = [
    # Tests with 1x1x1 clusters
    dict(layouts=_TNN, alignments=[16, 16, 16], cluster_shape=[1, 1, 1],
         threadblock_shape=[128, 128, 128], stages=3),
    dict(layouts=_TNT, alignments=[16, 16, 16], cluster_shape=[1, 1, 1],
         threadblock_shape=[128, 128, 128], stages=None),
    dict(layouts=_TNT, alignments=[16, 16, 8], cluster_shape=[1, 1, 1],
         threadblock_shape=[128, 128, 128], stages=None),
    dict(layouts=_TNT, alignments=[16, 16, 16], cluster_shape=[1, 1, 1],
         threadblock_shape=[64, 128, 128], stages=None),
    dict(layouts=_TNT, alignments=[16, 16, 16], cluster_shape=[1, 1, 1],
         threadblock_shape=[128, 64, 32], stages=None),
    dict(layouts=_TNT, alignments=[4, 4, 16], cluster_shape=[1, 1, 1],
         threadblock_shape=[128, 128, 128], stages=None),
    # Tests with different cluster shapes
    dict(layouts=_TNT, alignments=[16, 16, 16], cluster_shape=[2, 2, 1],
         threadblock_shape=[128, 128, 128], stages=None),
    dict(layouts=_TNT, alignments=[16, 16, 16], cluster_shape=[1, 4, 1],
         threadblock_shape=[128, 128, 128], stages=None),
    # Tests with warp-specialized ping-pong schedule
    dict(layouts=_TNT, alignments=[16, 16, 16], cluster_shape=[2, 1, 1],
         threadblock_shape=[128, 128, 128], stages=None,
         kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong,
         epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized),
]

for _case in _TENSOROP_CASES:
    add_test_tensorop(element_output=_S8, element_accumulator=_S32, **_case)

# Tests for SIMT
add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt)
add_test_simt(
    layouts=_TNN,
    alignments=[1, 1, 1],
    element_output=_S8,
    element_accumulator=_S32,
    cluster_shape=[1, 1, 1],
    threadblock_shape=[64, 32, 8],
    stages=2,
)


if __name__ == '__main__':
    unittest.main()
a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py new file mode 100644 index 0000000000000000000000000000000000000000..6ffda5b47e37f184c2352f0ee4e737635dbd4147 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py @@ -0,0 +1,423 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

from math import prod
import os
import re
import subprocess

import torch

from cutlass_library import (
    DataType,
    DataTypeSize,
    GemmUniversalMode,
    LayoutType,
    OpcodeClass,
    ShortDataTypeNames,
    SwizzlingFunctor
)

from cutlass_cppgen.backend import compiler
from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal
from cutlass_cppgen.backend.reduction_operation import ReductionArguments, ReductionOperation
from cutlass_cppgen.shape import GemmCoord, MatrixCoord
from cutlass_cppgen.utils.datatypes import torch_type


class GemmUniversalLauncher:
    """
    Compiles a GEMM operation, runs it on random inputs and compares the
    result against a PyTorch reference.
    """

    def __init__(
        self,
        operation,
        seed=2080,
        verification=True,
        iterations=500,
        compiler_mode="nvcc",
        **kwargs,
    ) -> None:
        """
        Select the compiler, compile ``operation`` and record the data types
        and random-value range used by :meth:`run`.

        :param operation: GEMM operation to compile and test
        :param seed: seed passed to ``torch.random.manual_seed`` at the start of every run
        :param verification: whether :meth:`run` compares the result against the reference
        :param iterations: accepted but not used by the code in this class
        :param compiler_mode: ``"nvcc"`` or ``"nvrtc"``
        :param kwargs: accepted but not used by the code in this class

        :raises ValueError: if ``compiler_mode`` is neither ``"nvcc"`` nor ``"nvrtc"``
        """
        self.math_operation = operation.tile_description.math_instruction.math_operation
        self.verification = verification

        if compiler_mode == "nvcc":
            compiler.nvcc()
        elif compiler_mode == "nvrtc":
            compiler.nvrtc()
        else:
            # ValueError is a subclass of Exception, so existing handlers still catch it.
            raise ValueError(f"Unexpected compiler string {compiler_mode}")

        op_list = [operation]
        if operation.arch < 90:
            # Split K via Python is currently only supported for pre-SM90 kernels
            self.reduction_operation: ReductionOperation = ReductionOperation(
                shape=MatrixCoord(4, 32 * operation.C.alignment),
                C=operation.C,
                element_accumulator=operation.tile_description.math_instruction.element_accumulator,
                element_compute=operation.epilogue_functor.element_epilogue,
                epilogue_functor=operation.epilogue_functor,
                count=operation.C.alignment,
            )
            op_list.append(self.reduction_operation)

        compiler.add_module(op_list, bypass_cache=False)

        self.operation = operation

        # When the operation is "switched", the A and B element types trade places.
        self.dtype_A = torch_type(operation.A.element if not self.operation.switched else self.operation.B.element)
        self.dtype_B = torch_type(operation.B.element if not self.operation.switched else self.operation.A.element)
        self.dtype_C = torch_type(operation.C.element)
        self.dtype_D = torch_type(operation.epilogue_functor.element_output)

        # The random-value range shrinks with the smaller operand's DataTypeSize
        # (presumably its width in bits -- confirm against cutlass_library).
        element_size = min(DataTypeSize[operation.A.element], DataTypeSize[operation.B.element])

        if element_size == 1:
            self.rand_max = 1
            self.rand_min = 0
        elif element_size <= 8:
            self.rand_max = 1
            self.rand_min = -1
        elif element_size == 16:
            self.rand_max = 4
            self.rand_min = -4
        else:
            self.rand_max = 8
            self.rand_min = -8

        self.seed = seed

        self.compute_type = operation.epilogue_functor.element_epilogue
        self.accumulator_type = operation.tile_description.math_instruction.element_accumulator

    def print_problem_size(self, p, mode, batch_count):
        """
        Print the M/N/K extents of ``p``, the batch count and the GEMM mode.

        :param p: problem size exposing ``m``, ``n`` and ``k`` attributes
        :param mode: GEMM mode; the three modes handled below are printed by name
        :param batch_count: batch count to print
        """
        if mode == GemmUniversalMode.Gemm:
            mode = "Gemm"
        elif mode == GemmUniversalMode.Batched:
            mode = "GemmBatched"
        elif mode == GemmUniversalMode.GemmSplitKParallel:
            mode = "GemmSplitKParallel"
        print(f"problem: {p.m}, {p.n}, {p.k}\n batch_count: {batch_count}\n mode: {mode}")

    def uniform_init(self, shape, dtype, layout):
        """
        Create a random tensor and return it in two forms.

        :param shape: shape of the reference tensor
        :param dtype: torch data type of the tensor
        :param layout: ``LayoutType`` of the operand; anything other than row-major
            transposes the last two dimensions of the kernel-side copy

        :return: tuple ``(data_cutlass, data_ref)``: the CUDA tensor handed to the
            kernel and the tensor used by the reference computation
        """
        size = prod(shape)
        if dtype.is_floating_point:
            # Initialize data in FP32 and call convert to the data type we desire.
            # This is a workaround for the following error that occurs when attempting to
            # call uniform_ on a tensor with torch.float8_e4m3fn data:
            #   RuntimeError: "check_uniform_bounds" not implemented for 'Float8_e4m3fn'
            data = torch.ceil(
                torch.empty(size=(size,), dtype=torch.float32, device="cuda").uniform_(
                    self.rand_min - 0.5, self.rand_max - 0.5)
            ).to(dtype)
        else:
            # PyTorch does not currently support integer-typed matrix multiplications on GPU.
            # Fall back to CPU for integer type references.
            data = torch.empty(size=(size,), dtype=dtype, device="cpu").random_(self.rand_min, self.rand_max + 1)

        # getattr with a -1 default keeps this working on torch builds without FP8 types.
        fp8_e4m3 = getattr(torch, "float8_e4m3fn", -1)
        fp8_e5m2 = getattr(torch, "float8_e5m2", -1)
        is_fp8 = dtype == fp8_e4m3 or dtype == fp8_e5m2

        if dtype == torch.float64 or dtype == torch.float32 or is_fp8:
            data = data.to("cpu")

        data_ref = data.reshape(shape)

        if layout == LayoutType.RowMajor:
            data_cutlass = data_ref
        else:
            data_cutlass = data_ref.transpose(-1, -2).contiguous()

        data_cutlass = data_cutlass.to("cuda")

        # As of this writing, few operations in PyTorch are supported with FP8 data.
        # Thus, we perform computation in FP32 for FP8 reference checks.
        if is_fp8:
            data_ref = data_ref.to(torch.float32)

        return data_cutlass, data_ref

    def reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta):
        """
        Compute ``alpha * (A @ B) + beta * C`` with PyTorch.

        :param problem_size: not used by this method
        :param tensor_A: reference tensor for operand A
        :param tensor_B: reference tensor for operand B
        :param tensor_C: reference tensor for operand C, or None to skip the ``beta * C`` term
        :param alpha: scalar applied to ``A @ B``
        :param beta: scalar applied to ``C``

        :return: reference output converted to ``self.dtype_D``
        """
        # Handle mixed-input cases by casting to the larger data type and overriding
        # to whatever the data type of the larger type is
        if self.dtype_A != self.dtype_B:
            if DataTypeSize[self.operation.A.element] < DataTypeSize[self.operation.B.element]:
                tensor_A = tensor_A.to(self.dtype_B).to(tensor_B.device)
            else:
                tensor_B = tensor_B.to(self.dtype_A).to(tensor_A.device)

        devices = [x.device.type for x in [tensor_A, tensor_B]]
        if tensor_C is not None:
            devices.append(tensor_C.device.type)

        # If any tensor is on CPU, place all tensors on CPU unless only
        # tensor C is on CPU
        if "cpu" in devices and devices != ["cuda", "cuda", "cpu"]:
            device = torch.device("cpu")
        else:
            device = tensor_A.device

        tensor_A = tensor_A.to(device)
        tensor_B = tensor_B.to(device)
        if tensor_C is not None:
            tensor_C = tensor_C.to(device)

        dtype = torch_type(self.compute_type)
        alpha_torch = torch.tensor([alpha], device=device).to(dtype)
        beta_torch = torch.tensor([beta], device=device).to(dtype)

        tmp = tensor_A @ tensor_B
        tensor_D_ref = (alpha_torch * tmp)
        if tensor_C is not None:
            tensor_D_ref += (tensor_C * beta_torch)
        return tensor_D_ref.to(self.dtype_D)

    def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, beta=0.0):
        """
        Run the compiled GEMM once and, when verification is enabled, compare
        its output with :meth:`reference`.

        :param mode: ``GemmUniversalMode`` to launch with
        :param problem_size: problem size exposing ``m``, ``n`` and ``k``
        :param batch_count: value passed to the kernel as ``batch``
        :param split_k_slices: number of split-K slices
        :param alpha: epilogue scalar applied to the accumulator
        :param beta: epilogue scalar applied to C

        :return: True if verification is disabled or the output equals the reference
        """
        torch.random.manual_seed(self.seed)

        # Assign an actual batch count in cases where we are not running in batched mode.
        # This is to differentiate between the number of split K slices and the batch count,
        # which are overloaded within the single `batch_count` variable.
        if mode == GemmUniversalMode.Batched:
            true_batch_count = batch_count
        else:
            true_batch_count = 1

        def transpose(layout):
            if layout == LayoutType.RowMajor:
                return LayoutType.ColumnMajor
            else:
                return LayoutType.RowMajor

        tensor_A, tensor_A_ref = self.uniform_init(
            (true_batch_count, problem_size.m, problem_size.k),
            self.dtype_A,
            self.operation.A.layout if not self.operation.switched else transpose(self.operation.B.layout),
        )
        tensor_B, tensor_B_ref = self.uniform_init(
            (true_batch_count, problem_size.k, problem_size.n),
            self.dtype_B,
            self.operation.B.layout if not self.operation.switched else transpose(self.operation.A.layout),
        )
        if self.dtype_C is not None:
            tensor_C, tensor_C_ref = self.uniform_init(
                (true_batch_count, problem_size.m, problem_size.n),
                self.dtype_C,
                self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout),
            )
        else:
            tensor_C = None
            tensor_C_ref = None

        # uniform_init is used here only to get a tensor of the right shape,
        # dtype and layout; its contents are replaced by zeros.
        tensor_D, _ = self.uniform_init(
            (true_batch_count, problem_size.m, problem_size.n),
            self.dtype_D,
            self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout),
        )
        tensor_D = torch.zeros_like(tensor_D)

        if self.compute_type in [DataType.s8, DataType.s32, DataType.u8, DataType.u32]:
            alpha = int(alpha)
            beta = int(beta)

        #
        # Launch kernel
        #

        arguments = GemmArguments(
            operation=self.operation,
            problem_size=problem_size,
            A=tensor_A,
            B=tensor_B,
            C=tensor_C,
            D=tensor_D,
            output_op=self.operation.epilogue_type(alpha, beta),
            gemm_mode=mode,
            split_k_slices=split_k_slices,
            batch=batch_count,
        )

        if mode == GemmUniversalMode.GemmSplitKParallel:
            # NOTE(review): self.reduction_operation only exists when operation.arch < 90
            # (see __init__); callers must not request this mode on SM90+.
            reduction_arguments = ReductionArguments(
                self.reduction_operation,
                problem_size=[problem_size.m, problem_size.n],
                partitions=split_k_slices,
                workspace=arguments.ptr_D,
                destination=tensor_D,
                source=tensor_C,
                output_op=self.reduction_operation.epilogue_type(alpha, beta),
            )

        self.operation.run(arguments)

        if mode == GemmUniversalMode.GemmSplitKParallel:
            self.reduction_operation.run(reduction_arguments)

        passed = True

        if self.verification:
            if mode == GemmUniversalMode.GemmSplitKParallel:
                reduction_arguments.sync()

                # Free memory allocated by args because we are not
                # calling `arguments.sync()` in this case (which will free memory)
                arguments.free()
            else:
                arguments.sync()
            tensor_D_ref = self.reference(
                problem_size,
                tensor_A_ref,
                tensor_B_ref,
                tensor_C_ref,
                alpha,
                beta,
            )

            tensor_D_ref = tensor_D_ref.to('cuda')

            if self.operation.switched or self.operation.C.layout == LayoutType.ColumnMajor:
                tensor_D = tensor_D.transpose(-1, -2).contiguous()

            passed = tensor_D.equal(tensor_D_ref)

            # Report the failing configuration. A plain condition is used rather than
            # an assert so the report is still printed when Python runs with -O.
            if not passed:
                self.print_problem_size(problem_size, mode, batch_count)
        del arguments
        if mode == GemmUniversalMode.GemmSplitKParallel:
            del reduction_arguments

        return passed


def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal", compilation_mode="nvcc"):
    """
    Run ``operation`` over a sweep of problem sizes, batch counts and modes.

    :param operation: GEMM operation to test
    :param testcase: ``"multistage"`` selects a small fixed sweep; any other value
        selects the general sweep. ``"interleaved"`` is rejected by an assertion.
    :param compilation_mode: compiler passed to ``GemmUniversalLauncher`` (``"nvcc"`` or ``"nvrtc"``)

    :return: False as soon as one configuration fails, True otherwise
    """
    passed = True

    minimum_operand_element_size = min(
        DataTypeSize[operation.A.element], DataTypeSize[operation.B.element]
    )
    opcode_class = operation.tile_description.math_instruction.opcode_class

    if opcode_class == OpcodeClass.Simt:
        alignment = 1
    else:
        alignment = 128 // minimum_operand_element_size

    alignment_m = alignment
    alignment_n = alignment
    alignment_k = alignment

    # INT8 alignment constraints
    if opcode_class == OpcodeClass.Simt:
        A_is_s8 = operation.A.element == DataType.s8
        B_is_s8 = operation.B.element == DataType.s8

        if A_is_s8 and operation.A.layout == LayoutType.ColumnMajor:
            alignment_m = 4
        # B_is_s8 is already a bool; comparing it with DataType.s8 was always False,
        # so this rule never applied. The N extent depends on B's layout, mirroring
        # the M/A rule above.
        if B_is_s8 and operation.B.layout == LayoutType.RowMajor:
            alignment_n = 4
        if A_is_s8 and B_is_s8 and (operation.A.layout == LayoutType.RowMajor or operation.B.layout == LayoutType.ColumnMajor):
            alignment_k = 4

    threadblock_k = operation.tile_description.threadblock_shape[2]

    assert testcase != "interleaved"

    supports_split_k = operation.arch < 90 and not operation.swizzling_functor == SwizzlingFunctor.StreamK

    if testcase == "multistage":
        modes = [GemmUniversalMode.Gemm]
        problem_size_m = [16, 528]
        problem_size_n = [16, 528]
        problem_size_k = [
            threadblock_k,
            threadblock_k * operation.tile_description.stages
            + operation.tile_description.math_instruction.instruction_shape[2],
        ]
        problem_alpha = [1.0]
        problem_beta = [0.0]
        batch_counts = [1]
    else:
        modes = [GemmUniversalMode.Gemm]
        batch_counts = [1, 2, 3, 5, 7]
        if supports_split_k:
            modes.append(GemmUniversalMode.GemmSplitKParallel)

        problem_size_m = [alignment_m, 512 - 3 * alignment_m]
        problem_size_n = [alignment_n, 512 - 2 * alignment_n]
        # A stage count of None is replaced by 7 when sizing K.
        if operation.tile_description.stages is None:
            stages_for_k_calc = 7
        else:
            stages_for_k_calc = operation.tile_description.stages
        problem_size_k = [
            alignment_k,
            threadblock_k * stages_for_k_calc - alignment_k,
            threadblock_k * stages_for_k_calc * 3 - alignment_k,
        ]
        problem_alpha = [1.0]
        problem_beta = [2.0]

    testbed = GemmUniversalLauncher(operation, compiler_mode=compilation_mode)

    for mode in modes:
        for m in problem_size_m:
            for n in problem_size_n:
                for k in problem_size_k:
                    for batch_count in batch_counts:
                        for alpha in problem_alpha:
                            for beta in problem_beta:
                                # skip very small K problems
                                if testcase == "universal":
                                    if k // batch_count < 2 * threadblock_k:
                                        continue

                                problem_size = GemmCoord(m, n, k)

                                if supports_split_k:
                                    split_k_slices = batch_count
                                else:
                                    split_k_slices = 1

                                # A plain GEMM with more than one batch is launched in batched mode.
                                overridden_mode = mode
                                if mode == GemmUniversalMode.Gemm and batch_count > 1:
                                    overridden_mode = GemmUniversalMode.Batched

                                passed = testbed.run(
                                    overridden_mode,
                                    problem_size,
                                    batch_count,
                                    split_k_slices,
                                    alpha,
                                    beta,
                                )

                                if not passed:
                                    return False

    return passed
a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..bc5e7467b1e0040ce3012ff8541dfbac381bb861 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py @@ -0,0 +1,44 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

import pathlib
import unittest


if __name__ == '__main__':
    # Discover every gemm_*.py test module that sits next to this script.
    here = pathlib.Path(__file__).parent.resolve()
    suite = unittest.TestLoader().discover(str(here) + '/', 'gemm_*.py')

    # Run the suite with the text runner and raise if any test did not succeed.
    outcome = unittest.TextTestRunner().run(suite)
    if not outcome.wasSuccessful():
        raise Exception('Test cases failed')
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#
#################################################################################################

from cutlass_library import SubstituteTemplate

import cutlass_cppgen
from cutlass_library import (
    DataTypeNames,
    EpilogueScheduleSuffixes,
    KernelScheduleSuffixes,
    LayoutType,
    OpcodeClassNames,
    ShortDataTypeNames,
    ShortLayoutTypeNames
)
from cutlass_cppgen.backend import library

from gemm_testbed import test_all_gemm


class Layout:
    """
    Utility class to map transpose and non-transpose terminology to row- and column-major terminology
    """

    T = LayoutType.RowMajor
    N = LayoutType.ColumnMajor


class LayoutCombination:
    """
    Utility class defining all combinations of row- and column-major layouts for the operands of a GEMM
    """

    NNN = (Layout.N, Layout.N, Layout.N)
    NNT = (Layout.N, Layout.N, Layout.T)
    NTN = (Layout.N, Layout.T, Layout.N)
    NTT = (Layout.N, Layout.T, Layout.T)
    TNN = (Layout.T, Layout.N, Layout.N)
    TNT = (Layout.T, Layout.N, Layout.T)
    TTN = (Layout.T, Layout.T, Layout.N)
    TTT = (Layout.T, Layout.T, Layout.T)


def get_name(
    layouts,
    alignments,
    element_output,
    element_accumulator,
    element_epilogue,
    cluster_shape,
    threadblock_shape,
    stages,
    element_a,
    element_b,
    element_c,
    arch,
    opclass,
    kernel_schedule=None,
    epilogue_schedule=None,
    suffix="",
):
    """
    Generates a procedural name for a test case.

    ``element_output`` and ``element_epilogue`` are accepted but do not appear
    in the generated name.

    :param layouts: indexable container of layouts of A, B, and C operands
    :param alignments: indexable container of alignments of A, B, and C operands
    :param element_output: data type of the output element
    :param element_accumulator: data type used in accumulation
    :param element_epilogue: data type used in computing the epilogue
    :param cluster_shape: indexable container of dimensions of threadblock cluster to be launched
    :param threadblock_shape: indexable container of dimensions of threadblock tiles
    :param stages: number of pipeline stages to use in the kernel; None is rendered as "auto"
    :type stages: int
    :param element_a: data type of operand A
    :param element_b: data type of operand B
    :param element_c: data type of operand C
    :param arch: compute capability of kernel being generated
    :type arch: int
    :param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
    :type opclass: cutlass_cppgen.OpcodeClass
    :param kernel_schedule: kernel_schedule type
    :type kernel_schedule: cutlass_cppgen.KernelScheduleType
    :param epilogue_schedule: epilogue_schedule type
    :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType
    :param suffix: additional string to add to the suffix of the name
    :type suffix: str

    :return: str
    """
    name_format = "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${e}${suffix}"
    return SubstituteTemplate(
        name_format,
        {
            "arch": str(arch),
            "eA": DataTypeNames[element_a],
            "eB": DataTypeNames[element_b],
            "eC": DataTypeNames[element_c],
            "lA": ShortLayoutTypeNames[layouts[0]],
            "lB": ShortLayoutTypeNames[layouts[1]],
            "lC": ShortLayoutTypeNames[layouts[2]],
            "opclass": OpcodeClassNames[opclass],
            "acc": DataTypeNames[element_accumulator],
            "cM": str(cluster_shape[0]),
            "cN": str(cluster_shape[1]),
            "cK": str(cluster_shape[2]),
            "tbM": str(threadblock_shape[0]),
            "tbN": str(threadblock_shape[1]),
            "tbK": str(threadblock_shape[2]),
            "stages": str(stages) if stages is not None else "auto",
            "aA": str(alignments[0]),
            "aB": str(alignments[1]),
            "aC": str(alignments[2]),
            "k": "" if kernel_schedule is None else KernelScheduleSuffixes[kernel_schedule],
            "e": "" if epilogue_schedule is None else EpilogueScheduleSuffixes[epilogue_schedule],
            "suffix": "" if suffix is None else suffix,
        },
    )


def add_test_gemm(
    cls=None,
    cc=None,
    element=None,
    layouts=None,
    alignments=None,
    element_output=None,
    element_accumulator=None,
    cluster_shape=None,
    threadblock_shape=None,
    warp_count=None,
    stages=None,
    opclass=None,
    swizzle=None,
    kernel_schedule=None,
    epilogue_schedule=None,
    compilation_modes=('nvcc', 'nvrtc'),
    element_A=None,
    element_B=None,
    element_C=None):
    """
    Create test-running functions with the given specification and set them as methods of ``cls``.
    One method is added per entry of ``compilation_modes``.

    :param cls: class to which the generated method will be added
    :type cls: type
    :param cc: compute capability to compile for
    :type cc: int
    :param element: data type of A and B operands; also the fallback for C, output and accumulator
    :type element: cutlass_cppgen.DataType
    :param layouts: layouts of A, B, and C operands
    :type layouts: list or tuple
    :param alignments: alignments of A, B, and C operands
    :type alignments: list or tuple
    :param element_output: data type of the output element
    :type element_output: cutlass_cppgen.DataType
    :param element_accumulator: data type used in accumulation
    :type element_accumulator: cutlass_cppgen.DataType
    :param cluster_shape: dimensions of clusters
    :type cluster_shape: list or tuple
    :param threadblock_shape: dimensions of threadblock tiles
    :type threadblock_shape: list or tuple
    :param warp_count: warps to be launched per threadblock dimension
    :type warp_count: list or tuple
    :param stages: number of pipeline stages to use in the kernel
    :type stages: int
    :param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
    :type opclass: cutlass_cppgen.OpcodeClass
    :param swizzle: threadblock swizzling functor
    :param kernel_schedule: kernel schedule to use
    :type kernel_schedule: cutlass_cppgen.KernelScheduleType
    :param epilogue_schedule: epilogue schedule to use
    :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType
    :param compilation_modes: compilers to use in testing the kernel (options: 'nvrtc', 'nvcc')
    :type compilation_modes: list or tuple
    :param element_A: data type of operand A. If set, overrides ``element``
    :type element_A: cutlass_cppgen.DataType
    :param element_B: data type of operand B. If set, overrides ``element``
    :type element_B: cutlass_cppgen.DataType
    :param element_C: data type of operand C. If set, overrides ``element``
    :type element_C: cutlass_cppgen.DataType
    """

    if element_A is None:
        element_A = element
    if element_B is None:
        element_B = element
    if element_C is None:
        element_C = element
    if element_output is None:
        element_output = element
    if element_accumulator is None:
        element_accumulator = element

    for compilation_mode in compilation_modes:
        # Bind the current mode as a default argument. A plain closure would look the
        # variable up when the test runs, i.e. after this loop has finished, so every
        # generated test would use the last mode regardless of the suffix in its name.
        def run(self, compilation_mode=compilation_mode):
            """
            Dynamically-generated function that constructs a GEMM operation and verifies it against
            multiple test cases.
            """

            layout_A, layout_B, layout_C = layouts
            alignment_A, alignment_B, alignment_C = alignments

            plan = cutlass_cppgen.op.Gemm(element_A=element_A, element_B=element_B,
                                          element_C=element_C, element_D=element_output,
                                          layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
                                          element_accumulator=element_accumulator,
                                          kernel_cc=cc)

            plan.opclass = opclass
            if swizzle is not None:
                plan.swizzling_functor = swizzle

            td = plan.tile_descriptions()[0]

            if warp_count is not None:
                td.warp_count = warp_count
            td.threadblock_shape = threadblock_shape
            td.stages = stages
            td.cluster_shape = cluster_shape
            # NOTE(review): kernel_schedule and epilogue_schedule are only used in the
            # test name below; nothing here applies them to the tile description or
            # the plan -- confirm whether that is intended.
            op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C)
            self.assertTrue(test_all_gemm(op, 'universal', compilation_mode=compilation_mode))

        element_epilogue = element_accumulator
        name = get_name(
            layouts=layouts, alignments=alignments, element_output=element_output, element_accumulator=element_accumulator,
            element_epilogue=element_epilogue, cluster_shape=cluster_shape, threadblock_shape=threadblock_shape,
            stages=stages, element_a=element_A, element_b=element_B, element_c=element_C, arch=cc, opclass=opclass,
            kernel_schedule=kernel_schedule, epilogue_schedule=epilogue_schedule, suffix=f'_{compilation_mode}')

        setattr(cls, name, run)
All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Tests for a successful installation of the CUTLASS Python interface +""" + +import os +import unittest + +import cutlass_cppgen +import cutlass_library + + +class InstallationTest(unittest.TestCase): + def test_cutlass_source_paths(self): + """ + Tests that CUTLASS source is available as part of the cutlass and cutlass_library packages + """ + src_file = 'include/cutlass/cutlass.h' + library_file = os.path.join(cutlass_library.source_path, src_file) + cutlass_file = os.path.join(cutlass_cppgen.CUTLASS_PATH, src_file) + assert os.path.isfile(library_file), f"Unable to locate file {library_file}. Installation has not succeeded." + assert os.path.isfile(cutlass_file), f"Unable to locate file {cutlass_file}. Installation has not succeeded." + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..2b5d46d45d617198a46bec85cd7218cb5431a7b1 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py @@ -0,0 +1,284 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Tests the high-level Conv2d interface +""" + +from math import ceil +import unittest + +import cutlass_cppgen +import cutlass_cppgen.utils.datatypes as datatypes +from cutlass_cppgen.backend.utils.device import device_cc +from utils import ExpectException +import os + + +class Conv2dEquivalence: + """ + Helper class for testing the equivalence of different constructions of the Conv2d interface + """ + def __init__(self, conv_kind, element_A, element_B, element_C, element_D, element_accumulator, + alignment_A, alignment_B, alignment_C): + + self.element_A = element_A + self.element_B = element_B + self.element_C = element_C + self.element_D = element_D + self.element_accumulator = element_accumulator + self.alignment_A = alignment_A + self.alignment_B = alignment_B + self.alignment_C = alignment_C + + self.conv_kind = conv_kind + + self.plan = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, element_A=element_A, element_B=element_B, element_C=element_C, + element_D=element_D, element_accumulator=element_accumulator) + + self.op = self.plan.construct( + alignment_A=self.alignment_A, alignment_B=self.alignment_B, + alignment_C=self.alignment_C) + + def _plans_equal(self, other_plan) -> bool: + """ + Compares whether two plans are equal + + :param other_plan: plan to compare against the default Conv2d + :type other_plan: cutlass_cppgen.op.Conv2d + + :return: whether `other_plan` is equivalent to `self.plan` + :rtype: bool + """ + other_op = other_plan.construct( + alignment_A=self.alignment_A, alignment_B=self.alignment_B, + alignment_C=self.alignment_C) + + return self.op.rt_module.emit() == other_op.rt_module.emit() + + def generic_test(self): + """ + Tests the equivalence of various constructions of the Conv2d interface when using CUTLASS data types + and layouts for constructing the Conv2d interface + """ + if not datatypes.is_numpy_available(): + return + 
+ # Test when specifying all parameters + plan_other = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, + element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A + plan_other = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, + element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + element=self.element_A) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A and B as tensors using generic element and output + plan_other = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, + element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + element=self.element_A) + assert self._plans_equal(plan_other) + + # Test without explicit accumulator. Only run if the type of C and the accumulator are equal + if self.element_C == self.element_accumulator: + plan_other = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, + element_C=self.element_C, + element_D=self.element_D, + element=self.element_A) + assert self._plans_equal(plan_other) + + # Test with only the generic types. 
Only run if the types of A, B, C, D, and the accumulator are the same + if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D + and self.element_A == self.element_accumulator): + plan_other = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, element=self.element_A) + assert self._plans_equal(plan_other) + + def numpy_test(self): + """ + Tests the equivalence of various constructions of the Conv2d interface when using numpy as a frontend + """ + if not datatypes.is_numpy_available(): + return + + import numpy as np + type_A = datatypes.numpy_type(self.element_A) + type_B = datatypes.numpy_type(self.element_B) + type_C = datatypes.numpy_type(self.element_C) + type_D = datatypes.numpy_type(self.element_D) + type_accum = datatypes.numpy_type(self.element_accumulator) + + size = (2, 2) + A = np.zeros(size, dtype=type_A) + B = np.zeros(size, dtype=type_B) + C = np.zeros(size, dtype=type_C) + D = np.zeros(size, dtype=type_D) + + return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D) + + def torch_test(self): + """ + Tests the equivalence of various constructions of the Conv2d interface when using torch as a frontend + """ + if not datatypes.is_torch_available(): + return + + import torch + type_A = datatypes.torch_type(self.element_A) + type_B = datatypes.torch_type(self.element_B) + type_C = datatypes.torch_type(self.element_C) + type_D = datatypes.torch_type(self.element_D) + type_accum = datatypes.torch_type(self.element_accumulator) + + size = (2, 2) + + A = torch.empty(size, dtype=type_A) + B = torch.empty(size, dtype=type_B) + C = torch.empty(size, dtype=type_C) + D = torch.empty(size, dtype=type_D) + + return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D) + + def tensor_test(self, type_A, type_B, type_C, type_D, type_accum, A, B, C, D): + # Test when specifying all parameters via tensors + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D, 
element_accumulator=type_accum) + assert self._plans_equal(plan_np) + + # Test when specifying all parameters but A as tensors + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A) + assert self._plans_equal(plan_np) + + # Test when specifying all parameters but A and B as tensors and using generic element and output + if type_A == type_B: + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, C=C, D=D, element_accumulator=type_accum, element=type_A) + assert self._plans_equal(plan_np) + + # Test without explicit accumulator. Only run if the type of C and the accumulator are equal. + if type_C == type_accum: + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D) + assert self._plans_equal(plan_np) + + # Test with only the generic types. Only run if the types of A, B, C, D, and the accumulator are the same. + if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum): + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, element=type_A) + assert self._plans_equal(plan_np) + + def test_all(self): + """ + Runs all tests on the Conv2d interface + """ + self.generic_test() + self.numpy_test() + self.torch_test() + + +@unittest.skipIf(device_cc() <= 80, 'Device compute capability is insufficient for SM80 tests.') +class ConvEquivalenceTest(unittest.TestCase): + """ + Tests the equivalence of different constructions of the Conv2d interface + """ + pass + +type2alignment = { + cutlass_cppgen.DataType.f16: 8, + cutlass_cppgen.DataType.f32: 4 +} + +def add_test(conv_kind, element_A, element_B, element_C, element_D, element_accumulator): + + test_name = f"test_conv2d_{conv_kind}_{element_A}_{element_B}_{element_C}_{element_D}_{element_accumulator}" + + def run(self): + conv2d_eq = Conv2dEquivalence( + conv_kind=conv_kind, + element_A=element_A, element_B=element_B, + element_C=element_C, element_D=element_D, + element_accumulator=element_accumulator, + 
alignment_A=type2alignment[element_A], alignment_B=type2alignment[element_B], + alignment_C=type2alignment[element_C] + ) + conv2d_eq.test_all() + + setattr(ConvEquivalenceTest, test_name, run) + +for conv_kind in ["fprop", "wgrad", "dgrad"]: + for types in [ + [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16], + [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32], + [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16], + [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32], + [cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32] + ]: + add_test(conv_kind, types[0], types[1], types[2], types[3], types[4]) + + +@unittest.skipIf(device_cc() <= 80, 'Device compute capability is insufficient for SM80 tests.') +class Conv2dErrorTests(unittest.TestCase): + """ + Tests various error scenarios that arise with the high-level Conv2d interface + """ + + def test_alignment(self): + """ + Tests case in which the alignment specified is unsupported + """ + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f16) + + with ExpectException(True, 'Alignment 3 is not supported for F16. 
The construction should fail.'): + op = plan.construct(alignment_A=3, alignment_B=3, alignment_C=3) + + def test_invalid_tile_description(self): + """ + Tests scenarios in which an invalid tile description is provided for a given CC + """ + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f16) + + td = plan.tile_descriptions()[0] + td.threadblock_shape=[17, 32, 5] + + plan.tile_description = td + with ExpectException(True, 'The threadblock shape is invalid. The compilation should fail.'): + plan.compile() + # Clean up the error message + os.remove("./cutlass_python_compilation_device_error.txt") + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..e7d67f4d07f01b0936ff5796bfb6fe4c98b5c031 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py @@ -0,0 +1,254 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Test the EVT interface +""" + +import numpy as np +import unittest + +import cutlass_cppgen +from cutlass_cppgen import LayoutType, Tensor +from cutlass_cppgen.backend.utils.device import device_cc +from cutlass_cppgen.epilogue import reshape, permute + +from utils import ExpectException + + +@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only") +class EVTErrorTests(unittest.TestCase): + """ + Tests various error scenarios that arise with the EVT interface + """ + @unittest.skipIf(device_cc() != 90, "Only Sm90 EVT requires root node be 'D'") + def test_root_not_d(self): + """ + Test when "D" does not exist in Sm90 EVT + """ + def evt_root_not_d(accum, alpha): + F = accum * alpha + return F + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "alpha": 1.2, + "F": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with 
ExpectException(device_cc() == 90, + "SyntaxError: Sm90 EVT requires the epilogue to have a returned tensor D, " + "but the variable 'D' is not found in the return values.", True): + + cutlass_cppgen.epilogue.trace(evt_root_not_d, example_tensors) + + def test_no_accum(self): + """ + Test when "accum" is not in input arguments + """ + def evt_no_accum(alpha, C): + D = alpha * C + return D + + example_tensors = { + "C": self.fake_tensor(np.float16, (6, 512, 512)), + "alpha": 1.2, + "D": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(True, "SyntaxError: Cannot find 'accum' in the argument list.", True): + cutlass_cppgen.epilogue.trace(evt_no_accum, example_tensors) + + @unittest.skipIf(device_cc() != 90, "Only Sm90 EVT has concern on smem size") + def test_too_much_shared_memory(self): + """ + Test when the epilogue consumes too much shared memory + """ + def evt_too_much_shared_memory(accum, C1, C2, C3, C4, C5, C6, C7, C8): + D1 = accum + C1 + D2 = D1 + C2 + D3 = D2 + C3 + D4 = D3 + C4 + D5 = D4 + C5 + D6 = D5 + C6 + D7 = D6 + C7 + D = D7 + C8 + return D, D1, D2, D3, D4, D5, D6, D7 + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "C1": self.fake_tensor(np.float16, (6, 512, 512)), + "C2": self.fake_tensor(np.float16, (6, 512, 512)), + "C3": self.fake_tensor(np.float16, (6, 512, 512)), + "C4": self.fake_tensor(np.float16, (6, 512, 512)), + "C5": self.fake_tensor(np.float16, (6, 512, 512)), + "C6": self.fake_tensor(np.float16, (6, 512, 512)), + "C7": self.fake_tensor(np.float16, (6, 512, 512)), + "C8": self.fake_tensor(np.float16, (6, 512, 512)), + "D1": self.fake_tensor(np.float16, (6, 512, 512)), + "D2": self.fake_tensor(np.float16, (6, 512, 512)), + "D3": self.fake_tensor(np.float16, (6, 512, 512)), + "D4": self.fake_tensor(np.float16, (6, 512, 512)), + "D5": self.fake_tensor(np.float16, (6, 512, 512)), + "D6": self.fake_tensor(np.float16, (6, 512, 512)), + "D7": self.fake_tensor(np.float16, (6, 512, 512)), 
+ "D": self.fake_tensor(np.float16, (6, 512, 512)) + } + + epilogue_visitor = cutlass_cppgen.epilogue.trace(evt_too_much_shared_memory, example_tensors) + + plan = cutlass_cppgen.op.Gemm( + element=np.float16, layout=cutlass_cppgen.LayoutType.RowMajor, + element_accumulator=np.float32 + ) + + with ExpectException(True, + "RuntimeError: The epilogue consumes too much shared memory. " + "No valid tile description is found in the generator.", True): + plan.epilogue_visitor = epilogue_visitor + + def test_not_ssa(self): + """ + Test when the epilogue is not in SSA + """ + def evt_redefine(accum, C, alpha): + F = accum + C + F = F * alpha + D = F + return D, F + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "C": self.fake_tensor(np.float16, (6, 512, 512)), + "alpha": 1.5, + "D": self.fake_tensor(np.float16, (6, 512, 512)), + "F": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(True, "SyntaxError: Variable 'F' cannot be defined twice.", True): + cutlass_cppgen.epilogue.trace(evt_redefine, example_tensors) + + def evt_undefine(accum, alpha): + F = accum + C + D = F * alpha + return D, F + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "alpha": 1.5, + "D": self.fake_tensor(np.float16, (6, 512, 512)), + "F": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(True, "SyntaxError: Variable 'C' is undefined.", True): + cutlass_cppgen.epilogue.trace(evt_undefine, example_tensors) + + def test_missing_example_tensor(self): + """ + Test when the example tensor of an input/output variable is not provided + """ + def evt_missing_example_tensor(accum, C): + D = accum + C + return D + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "C": self.fake_tensor(np.float16, (6, 512, 512)), + } + + with ExpectException(True, "RuntimeError: Example input for D is not provided.", True): + cutlass_cppgen.epilogue.trace(evt_missing_example_tensor, 
example_tensors) + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "D": self.fake_tensor(np.float16, (6, 512, 512)), + } + + with ExpectException(True, "RuntimeError: Example input for C is not provided.", True): + cutlass_cppgen.epilogue.trace(evt_missing_example_tensor, example_tensors) + + def test_return_expression(self): + """ + Test when the return value is an expression + """ + def evt_return_expr(accum, C): + return accum + C + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "C": self.fake_tensor(np.float16, (6, 512, 512)), + } + + with ExpectException(True, "SyntaxError: Return value cannot be an expression", True): + cutlass_cppgen.epilogue.trace(evt_return_expr, example_tensors) + + def test_incompatible_shape(self): + """ + Test when the shape of example tensors are incompatible + """ + def evt_incompatible_shape(accum, C): + D = accum + C + return D + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 256, 512)), + "C": self.fake_tensor(np.float16, (6, 512, 512)), + "D": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(True, + "RuntimeError: Dimension mismatch between accum(6, 256, 512), C(6, 512, 512).", True): + cutlass_cppgen.epilogue.trace(evt_incompatible_shape, example_tensors) + + def test_no_matching_impl(self): + def evt_no_matching_impl(accum, bias): + D = accum + reshape(permute(bias, indices=(1, 0)), new_shape=(512, 1)) + return D + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 256)), + "bias": self.fake_tensor(np.float16, (16, 32)), + "D": self.fake_tensor(np.float16, (6, 512, 256)) + } + + with ExpectException(True, "NotImplementedError: No matching op for node bias with stride (0, (1, 32), 0).", True): + cutlass_cppgen.epilogue.trace(evt_no_matching_impl, example_tensors) + # + # Helper functions + # + + def fake_tensor(self, element, shape): + return Tensor(element=element, shape=shape, 
layout_tag=LayoutType.RowMajor) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..2913d5933f5342cc58b4f252657a724d2c7692da --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py @@ -0,0 +1,354 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Tests the high-level GEMM interface +""" + +from math import ceil +import unittest + +import cutlass_cppgen +import cutlass_cppgen.utils.datatypes as datatypes +from cutlass_cppgen.backend.utils.device import device_cc +from utils import ExpectException + + +class GemmEquivalence: + """ + Helper class for testing the equivalence of different constructions of the Gemm interface + """ + def __init__(self, element_A, element_B, element_C, element_D, element_accumulator, + layout_A, layout_B, layout_C, alignment_A, alignment_B, alignment_C): + self.element_A = element_A + self.element_B = element_B + self.element_C = element_C + self.element_D = element_D + self.element_accumulator = element_accumulator + self.layout_A = layout_A + self.layout_B = layout_B + self.layout_C = layout_C + self.alignment_A = alignment_A + self.alignment_B = alignment_B + self.alignment_C = alignment_C + self.plan = cutlass_cppgen.op.Gemm(element_A=element_A, element_B=element_B, element_C=element_C, + element_D=element_D, element_accumulator=element_accumulator, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C) + self.op = self.plan.construct(alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) + + def _plans_equal(self, other_plan) -> bool: + """ + Compares whether two plans are equal + + :param other_plan: plan to compare 
against the default GEMM + :type other_plan: cutlass_cppgen.op.Gemm + + :return: whether `other_plan` is equivalent to `self.plan` + :rtype: bool + """ + other_op = other_plan.construct(alignment_A=self.alignment_A, alignment_B=self.alignment_B, alignment_C=self.alignment_C) + + # Compare whether the operations are equal by comparing the C++ code that would be emitted for them + return self.op.rt_module.emit() == other_op.rt_module.emit() + + def generic_test(self): + """ + Tests the equivalence of various constructions of the Gemm interface when using CUTLASS data types + and layouts for constructing the Gemm interface + """ + if not datatypes.is_numpy_available(): + return + + # Test when specifying all parameters + plan_other = cutlass_cppgen.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + layout_A=self.layout_A, layout_B=self.layout_B, layout_C=self.layout_C) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A + plan_other = cutlass_cppgen.op.Gemm(element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + layout_B=self.layout_B, layout_C=self.layout_C, + element=self.element_A, layout=self.layout_A) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A and B as tensors and using generic element and output + # Only run this test if the layouts and types for A and B are equal. + if self.element_A == self.element_B and self.layout_A == self.layout_B: + plan_other = cutlass_cppgen.op.Gemm(element_C=self.element_C, element_D=self.element_D, element_accumulator=self.element_accumulator, + layout_C=self.layout_C, element=self.element_A, layout=self.layout_A) + assert self._plans_equal(plan_other) + + # Test without explicit accumulator. Only run if the type of C and the accumulator. 
+ if self.element_C == self.element_accumulator: + plan_other = cutlass_cppgen.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, layout_A=self.layout_A, layout_B=self.layout_B, + layout_C=self.layout_C) + assert self._plans_equal(plan_other) + + # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. + if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D + and self.element_A == self.element_accumulator and + self.layout_A == self.layout_B and self.layout_A == self.layout_C): + plan_other = cutlass_cppgen.op.Gemm(element=self.element_A, layout=self.layout_A) + assert self._plans_equal(plan_other) + + def numpy_test(self): + """ + Tests the equivalence of various constructions of the Gemm interface when using numpy as a frontend + """ + if not datatypes.is_numpy_available(): + return + + import numpy as np + type_A = datatypes.numpy_type(self.element_A) + type_B = datatypes.numpy_type(self.element_B) + type_C = datatypes.numpy_type(self.element_C) + type_D = datatypes.numpy_type(self.element_D) + type_accum = datatypes.numpy_type(self.element_accumulator) + + layout_to_order = { + cutlass_cppgen.LayoutType.RowMajor: 'C', + cutlass_cppgen.LayoutType.ColumnMajor: 'F' + } + size = (2, 2) + A = np.zeros(size, order=layout_to_order[self.layout_A], dtype=type_A) + B = np.zeros(size, order=layout_to_order[self.layout_B], dtype=type_B) + C = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_C) + D = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_D) + + # Test when specifying all parameters via tensors + plan_np = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=type_accum) + assert self._plans_equal(plan_np) + + # Test when specifying all parameters but A as tensors + plan_np = cutlass_cppgen.op.Gemm(B=B, C=C, D=D, element_accumulator=type_accum, 
element_A=type_A, layout_A=self.layout_A) + assert self._plans_equal(plan_np) + + # Test when specifying all parameters but A and B as tensors and using generic element and output + # Only run this test if the layouts and types for A and B are equal. + if type_A == type_B and self.layout_A == self.layout_B: + plan_np = cutlass_cppgen.op.Gemm(C=C, D=D, element_accumulator=type_accum, element=type_A, layout=self.layout_A) + assert self._plans_equal(plan_np) + + # Test without explicit accumulator. Only run if the type of C and the accumulator. + if type_C == type_accum: + plan_np = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D) + assert self._plans_equal(plan_np) + + # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. + if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum and + self.layout_A == self.layout_B and self.layout_A == self.layout_C): + plan_np = cutlass_cppgen.op.Gemm(element=type_A, layout=self.layout_A) + assert self._plans_equal(plan_np) + + def test_all(self): + """ + Runs all tests on the Gemm interface + """ + self.generic_test() + self.numpy_test() + + +class GemmEquivalenceTest(unittest.TestCase): + """ + Tests the equivalence of different constructions of the Gemm interface + """ + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") + def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_8_8_8(self): + gemm_eq = GemmEquivalence( + element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, + layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor, + alignment_A=8, alignment_B=8, alignment_C=8) + gemm_eq.test_all() + + @unittest.skipIf(device_cc() < 70, "Device compute capability is 
insufficient for FP16 Tensor Core tests.") + def test_gemm_equivalence_f16_f16_f16_f16_f32_ntn_8_8_8(self): + gemm_eq = GemmEquivalence( + element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, + layout_A=cutlass_cppgen.LayoutType.ColumnMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.ColumnMajor, + alignment_A=8, alignment_B=8, alignment_C=8) + gemm_eq.test_all() + + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") + def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_4_4_4(self): + gemm_eq = GemmEquivalence( + element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, + layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor, + alignment_A=8, alignment_B=8, alignment_C=8) + gemm_eq.test_all() + + @unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for F64 Tensor Core tests.") + def test_gemm_equivalence_f64_f64_f64_f64_f64_tnt_1_1_1(self): + gemm_eq = GemmEquivalence( + element_A=cutlass_cppgen.DataType.f64, element_B=cutlass_cppgen.DataType.f64, element_C=cutlass_cppgen.DataType.f64, + element_D=cutlass_cppgen.DataType.f64, element_accumulator=cutlass_cppgen.DataType.f64, + layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.ColumnMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor, + alignment_A=1, alignment_B=1, alignment_C=1) + gemm_eq.test_all() + + +class GemmErrorTests(unittest.TestCase): + """ + Tests various error scenarios that arise with the high-level Gemm interface + """ + + def test_alignment(self): + """ + Tests case in which the 
alignment specified is unsupported + """ + plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) + + with ExpectException(True, 'Alignment 16 is not supported for F16. The construction should fail.'): + op = plan.construct(alignment_A=16, alignment_B=16, alignment_C=16) + + def test_tensorop_availability(self): + """ + Tests case in which only SIMT operations are available but TensorOp is requested + """ + cc = device_cc() + + # F64 Tensor Core operations are only avaiable on certain devices + supports_tensorop_f64 = cc in [80, 89, 90] + plan = cutlass_cppgen.op.Gemm(cc=cc, element=cutlass_cppgen.DataType.f64, layout=cutlass_cppgen.LayoutType.RowMajor) + + error_msg = f'Incorrectly raised an exception for availability of TensorOp with F64 operands on SM{cc}' + with ExpectException(not supports_tensorop_f64, error_msg): + plan.opclass = cutlass_cppgen.OpcodeClass.TensorOp + + expected_opclass = cutlass_cppgen.OpcodeClass.TensorOp if supports_tensorop_f64 else cutlass_cppgen.OpcodeClass.Simt + assert plan.opclass == expected_opclass, f'Expected opclass to be {expected_opclass}, but received {plan.opclass} for SM{cc}' + + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for F16 Tensor Core tests.") + def test_opclass_switch(self): + """ + Tests cases in which the opcode class in question is switched (e.g., from TensorOp to SIMT) + """ + plan = cutlass_cppgen.op.Gemm( element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) + assert plan.opclass == cutlass_cppgen.OpcodeClass.TensorOp + + # Ensure that all tile descriptions have opclass of TensorOp + for td in plan.tile_descriptions(): + assert td.math_instruction.opcode_class == cutlass_cppgen.OpcodeClass.TensorOp + + plan.opclass = cutlass_cppgen.OpcodeClass.Simt + + # Ensure that all tile descriptions have opclass of Simt + for td in plan.tile_descriptions(): + assert td.math_instruction.opcode_class == 
cutlass_cppgen.OpcodeClass.Simt + + def test_invalid_tile_description(self): + """ + Tests scenarios in which an invalid tile description is provided for a given CC + """ + cc = device_cc() + plan = cutlass_cppgen.op.Gemm(cc=cc, element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) + td = plan.tile_descriptions()[0] + stages = td.stages + + # Zero stage count is valid for SM90+, as this is used to indicate that the builder's auto stage + # count should be used + with ExpectException(cc < 90, f'Requested zero stages'): + td.stages = 0 + plan.construct(td) + + if cc < 90: + with ExpectException(cc < 80, f'Requested more than 2 stages on SM{cc}'): + td.stages = 3 + plan.construct(td) + elif cc == 90: + original_kschedule = td.kernel_schedule + original_eschedule = td.epilogue_schedule + with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.NoSmemWarpSpecialized + td.stages = 3 + plan.construct(td) + # Reset schedules + td.kernel_schedule = original_kschedule + td.epilogue_schedule = original_eschedule + elif cc in [100, 101, 103]: + with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'): + td.stages = 3 + plan.construct(td) + + with ExpectException(True, f'Requested too many stages'): + td.stages = 100 + plan.construct(td) + + # Reset stage count + td.stages = stages + + cluster_shape = td.cluster_shape + with ExpectException(cc < 90, f'Requested non-unit cluster shape on SM{cc}'): + td.cluster_shape = [2, 1, 1] + plan.construct(td) + + # Reset cluster shape + td.cluster_shape = cluster_shape + + with ExpectException(cc < 90, f'Requested a non-auto schedule on SM{cc}'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong + td.epilogue_schedule = 
cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized + plan.construct(td) + + with ExpectException(cc == 90, f'Requested a non-auto kernel schedule with an auto epilogue schedule'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.ScheduleAuto + plan.construct(td) + + with ExpectException(cc == 90, f'Requested an auto kernel schedule with a non-auto epilogue schedule'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.ScheduleAuto + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized + plan.construct(td) + + with ExpectException(cc < 90, f'Requested a tile scheduler on SM{cc}'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative + td.tile_scheduler = cutlass_cppgen.TileSchedulerType.StreamK + plan.construct(td) + + # Ensure that all returned tile descriptions are unique + ops = {} + for i, td in enumerate(plan.tile_descriptions()): + op = plan.construct(td) + code_str = op.rt_module.emit() + if code_str in ops: + conflicting_td = ops[code_str] + assert False, f'Multiple tile descriptions emitted {code_str}\nTile descriptions are:\n{td}\n{conflicting_td}' + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9f93ca26e2d79a15dab4dd0045836ebd9fe62757 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py @@ -0,0 +1,69 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Helper functions & classes for interface test +""" +class ExpectException: + """ + Utility class to assert that an exception was raised when expected + + Example: + + .. highlight:: python + .. 
code-block:: python + + with ExceptionExpected(True, 'Division by zero'): + x = 1.0 / 0.0 + + :param exception_expected: whether an exception is expected to be raised + :type exception_expected: bool + :param message: message to print if an exception is raised when not expected or vice versa + :type message: str + """ + def __init__(self, exception_expected: bool, message: str = '', verify_msg=False): + self.exception_expected = exception_expected + self.message = message + self.verify_msg = verify_msg + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, traceback): + exception_raised = exc_type is not None + assert self.exception_expected == exception_raised, self.message + if self.verify_msg: + exc_message = f"{exc_type.__name__}: {exc_val}" + assert exc_message == self.message, f"expect error message {self.message}, got {exc_message}" + + # Suppress the exception + return True diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..b7cdc421ccffffeb7bd1696aaf9916330a6625ca --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py @@ -0,0 +1,75 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Utility script for discovering and running all PyCuTe tests +""" + +import argparse +import logging +import pathlib +import unittest + + +def numeric_log_level(log_level: str) -> int: + """ + Converts the string identifier of the log level into the numeric identifier used + in setting the log level + + :param x: string representation of log level (e.g., 'INFO', 'DEBUG') + :type x: str + + :return: numeric representation of log level + :rtype: int + """ + numeric_level = getattr(logging, log_level.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError(f"Invalid log level: {log_level}") + return numeric_level + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--log-level", default='info', type=numeric_log_level, required=False, + help='Logging level to be used by the generator script') + args = parser.parse_args() + + # Set the logging level based on the user-provided `--log-level` command-line option + logging.basicConfig(level=args.log_level) + + loader = unittest.TestLoader() + script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/' + tests = loader.discover(script_dir, "test_*.py") + test_runner = unittest.runner.TextTestRunner() + results = test_runner.run(tests) + if not results.wasSuccessful(): + raise Exception("Test cases failed") diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py new file mode 100644 index 0000000000000000000000000000000000000000..d4330377cab7079ea16422f194ddf4f2403ea507 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py @@ -0,0 +1,95 @@ 
+################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Unit tests for pycute.coalesce +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestCoalesce(unittest.TestCase): + def helper_test_coalesce(self, layout): + layoutR = coalesce(layout) + + _LOGGER.debug(f"{layout} => {layoutR}") + + self.assertEqual(size(layoutR), size(layout)) + + for i in range(size(layout)): + self.assertEqual(layoutR(i), layout(i)) + + def test_coalesce(self): + layout = Layout(1,0) + self.helper_test_coalesce(layout) + + layout = Layout(1,1) + self.helper_test_coalesce(layout) + + layout = Layout((2,4)) + self.helper_test_coalesce(layout) + + layout = Layout((2,4,6)) + self.helper_test_coalesce(layout) + + layout = Layout((2,4,6), (1,6,2)) + self.helper_test_coalesce(layout) + + layout = Layout((2,1,6), (1,7,2)) + self.helper_test_coalesce(layout) + + layout = Layout((2,1,6), (4,7,8)) + self.helper_test_coalesce(layout) + + layout = Layout((2,(4,6))) + self.helper_test_coalesce(layout) + + layout = Layout((2,4), (4,1)) + self.helper_test_coalesce(layout) + + layout = Layout((2,4,6), (24,6,1)) + self.helper_test_coalesce(layout) + + layout = Layout((2,1,3), (2,4,4)) + self.helper_test_coalesce(layout) + + layout = Layout(((2,2),(2,2)), ((1,4),(8,32))) + self.helper_test_coalesce(layout) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py new file mode 100644 index 0000000000000000000000000000000000000000..5a8684a55b19c90eae11ddd1cca011c2ff8270b5 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py @@ -0,0 +1,92 @@ 
+################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Unit tests for pycute.complement +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestComplement(unittest.TestCase): + def helper_test_complement(self, layout): + layoutR = complement(layout) + + _LOGGER.debug(f"{layout} => {layoutR}") + + # Post-condition: test disjointness of the codomains + for a in range(size(layout)): + for b in range(size(layoutR)): + assert (layout(a) != layoutR(b)) or (layout(a) == 0 and layoutR(b) == 0) + + def test_complement(self): + test = Layout(1,0) + self.helper_test_complement(test) + + test = Layout(1,1) + self.helper_test_complement(test) + + test = Layout(4,0) + self.helper_test_complement(test) + + test = Layout((2,4),(1,2)) + self.helper_test_complement(test) + + test = Layout((2,3),(1,2)) + self.helper_test_complement(test) + + test = Layout((2,4),(1,4)) + self.helper_test_complement(test) + + test = Layout((2,4,8),(8,1,64)) + self.helper_test_complement(test) + + test = Layout(((2,2),(2,2)),((1,4),(8,32))) + self.helper_test_complement(test) + + test = Layout((2,(3,4)),(3,(1,6))) + self.helper_test_complement(test) + + test = Layout((4,6),(1,6)) + self.helper_test_complement(test) + + test = Layout((4,10),(1,10)) + self.helper_test_complement(test) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py new file mode 100644 index 0000000000000000000000000000000000000000..6c27eb7fe6cbb7bbbea7bd644ac8e64a2fc853c9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py @@ -0,0 +1,213 @@ 
+################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Unit tests for pycute.composition +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestComposition(unittest.TestCase): + def helper_test_composition(self, layoutA, layoutB): + layoutR = composition(layoutA, layoutB) + + _LOGGER.debug(f"{layoutA} o {layoutB} => {layoutR}") + + # True post-condition: Every coordinate c of layoutB with L1D(c) < size(layoutR) is a coordinate of layoutR. + + # Test that R(c) = A(B(c)) for all coordinates c in layoutR + for i in range(size(layoutR)): + self.assertEqual(layoutR(i), layoutA(layoutB(i))) + + def test_composition(self): + layoutA = Layout(1,0) + layoutB = Layout(1,0) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(1,0) + layoutB = Layout(1,1) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(1,1) + layoutB = Layout(1,0) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(1,1) + layoutB = Layout(1,1) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4), (2)) + layoutB = Layout((4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((4), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4), (0)) + layoutB = Layout((4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((4), (0)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((1), (0)) + layoutB = Layout((4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((1), (0)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4), (2)) + 
layoutB = Layout((2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((2), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4), (2)) + layoutB = Layout((2), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12)) + layoutB = Layout((4,3)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12), (2)) + layoutB = Layout((4,3)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12)) + layoutB = Layout((4,3), (3,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12), (2)) + layoutB = Layout((4,3), (3,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12)) + layoutB = Layout((2,3), (2,4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3)) + layoutB = Layout((4,3)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3)) + layoutB = Layout((12)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3)) + layoutB = Layout((6), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3)) + layoutB = Layout((6,2), (2,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3), (3,1)) + layoutB = Layout((4,3)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3), (3,1)) + layoutB = Layout((12)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3), (3,1)) + layoutB = Layout((6), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3), (3,1)) + layoutB = Layout((6,2), (2,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((8,8)) + layoutB = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32))) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((8,8), (8,1)) + layoutB = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32))) + self.helper_test_composition(layoutA, layoutB) + + layoutA = 
Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32))) + layoutB = Layout(8, 4) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(((4,2)), ((1,16))) + layoutB = Layout((4,2), (2,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((2,2), (2,1)) + layoutB = Layout((2,2), (2,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,8,2)) + layoutB = Layout((2,2,2), (2,8,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,8,2), (2,8,1)) + layoutB = Layout((2,2,2), (1,8,2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,8,2), (2,8,1)) + layoutB = Layout((4,2,2), (2,8,1)) + self.helper_test_composition(layoutA, layoutB) + + # Pre-coalesced LHS + layoutA = Layout((4,6,8),(1,4,7)) + layoutB = Layout((6),(1)) + self.helper_test_composition(layoutA, layoutB) + + # Mid-layout truncation + layoutA = Layout((4,6,8,10),(2,3,5,7)) + layoutB = Layout(6,12) + self.helper_test_composition(layoutA, layoutB) + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py new file mode 100644 index 0000000000000000000000000000000000000000..0dbf443c9725735b0051d0a225a55eece9c663a8 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py @@ -0,0 +1,80 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Unit tests for pycute.int_tuple +""" + +import unittest + +from pycute import * + + +class TestIntTuple(unittest.TestCase): + def test_product(self): + self.assertEqual(product(2), 2) + + self.assertEqual(product((3,2)), 6) + + self.assertEqual(product(product(((2,3),4))), 24) + + def test_inner_product(self): + self.assertEqual(inner_product(2, 3), 6) + + self.assertEqual(inner_product((1,2), (3,2)), 7) + + self.assertEqual(inner_product(((2,3),4), ((2,1),2)), 15) + + def test_shape_div(self): + self.assertEqual(shape_div((3,4), 6), (1,2)) + + self.assertEqual(shape_div((3,4), 12), (1,1)) + + self.assertEqual(shape_div((3,4), 36), (1,1)) + + self.assertEqual(shape_div(((3,4),6), 36), ((1,1),2)) + + self.assertEqual(shape_div((6,(3,4)), 36), (1,(1,2))) + + def test_prefix_product(self): + self.assertEqual(prefix_product(2), 1) + + self.assertEqual(prefix_product((3,2)), (1,3)) + + self.assertEqual(prefix_product((3,2,4)), (1,3,6)) + + self.assertEqual(prefix_product(((2,3),4)), ((1,2),6)) + + self.assertEqual(prefix_product(((2,3),(2, 1, 2),( 5, 2, 1))), + ((1,2),(6,12,12),(24,120,240))) + + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..a6501fd6c7c6fc5a518e4d22bf93dc0e4746a8ba --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py @@ -0,0 +1,87 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Unit tests for pycute.left_inverse +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestLeftInverse(unittest.TestCase): + def helper_test_left_inverse(self, layout): + inv_layout = left_inverse(layout) + + _LOGGER.debug(f"{layout} => {inv_layout}") + + for i in range(size(layout)): + self.assertEqual(inv_layout(layout(i)), i) + + def test_left_inverse(self): + test = Layout(1,0) + self.helper_test_left_inverse(test) + + test = Layout((1,1),(0,0)) + self.helper_test_left_inverse(test) + + test = Layout(1,1) + self.helper_test_left_inverse(test) + + test = Layout(4,1) + self.helper_test_left_inverse(test) + + test = Layout(4,2) + self.helper_test_left_inverse(test) + + test = Layout((8,4),(1,8)) + self.helper_test_left_inverse(test) + + test = Layout((8,4),(4,1)) + self.helper_test_left_inverse(test) + + test = Layout((2,4,6),(1,2,8)) + self.helper_test_left_inverse(test) + + test = Layout((2,4,6),(4,1,8)) + self.helper_test_left_inverse(test) + + test = Layout((4,2),(1,16)) + self.helper_test_left_inverse(test) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..2ed9759d7808da8087fe9c76761d2dd9eaeab08b --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py @@ -0,0 +1,96 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Unit tests for pycute.right_inverse +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestRightInverse(unittest.TestCase): + def helper_test_right_inverse(self, layout): + inv_layout = right_inverse(layout) + + _LOGGER.debug(f"{layout} => {inv_layout}") + + for i in range(size(inv_layout)): + self.assertEqual(layout(inv_layout(i)), i) + + def test_right_inverse(self): + test = Layout(1,0) + self.helper_test_right_inverse(test) + + test = Layout((1,1),(0,0)) + self.helper_test_right_inverse(test) + + test = Layout((3,7),(0,0)) + self.helper_test_right_inverse(test) + + test = Layout(1,1) + self.helper_test_right_inverse(test) + + test = Layout(4,0) + self.helper_test_right_inverse(test) + + test = Layout(4,1) + self.helper_test_right_inverse(test) + + test = Layout(4,2) + self.helper_test_right_inverse(test) + + test = Layout((2,4),(0,2)) + self.helper_test_right_inverse(test) + + test = Layout((8,4),(1,8)) + self.helper_test_right_inverse(test) + + test = Layout((8,4),(4,1)) + self.helper_test_right_inverse(test) + + test = Layout((2,4,6),(1,2,8)) + self.helper_test_right_inverse(test) + + test = Layout((2,4,6),(4,1,8)) + self.helper_test_right_inverse(test) + + test = Layout((4,2),(1,16)) + self.helper_test_right_inverse(test) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb99a4833529e18fa22d65a235ce80dad372365 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py @@ -0,0 +1,59 @@ 
+################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Unit tests for pycute.typing +""" + +import logging +import unittest +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestTyping(unittest.TestCase): + def helper_test_typing(self, _cls, _obj, cls, expected: bool): + _LOGGER.debug(f"issubclass({_cls}, {cls})") + _LOGGER.debug(f"isinstance({_obj}, {cls})") + + self.assertEqual(expected, issubclass(_cls, cls)) + self.assertEqual(expected, isinstance(_obj, cls)) + + def test_typing(self): + self.helper_test_typing(int, 1, Integer, True) + self.helper_test_typing(float, 1., Integer, False) + self.helper_test_typing(str, 'hi', Integer, False) + self.helper_test_typing(bool, False, Integer, False) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h new file mode 100644 index 0000000000000000000000000000000000000000..86b7823785a9f2a957cf505740d6cfde45ccfef1 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h @@ -0,0 +1,102 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once +#pragma warning (disable : 4068 ) /* disable unknown pragma warnings for visual studio */ + +#pragma nv_diag_suppress boolean_controlling_expr_is_constant +#include +#pragma nv_diag_warning boolean_controlling_expr_is_constant +#pragma warning( disable : 4503) + +#include +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Gets a CUDA device +cudaDeviceProp GetCudaDevice(); + +/// Prints device properties +std::ostream &operator<<(std::ostream &out, cudaDeviceProp const &device); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Sets flags for Unit test +void FilterArchitecture(); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Reads environment variable `CUTLASS_UNIT_TEST_PROBLEM_COUNT` to control the number and order +// of problem sizes run by CUTLASS unit tests +int CutlassUnitTestProblemCount(); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// active test macro +#define CUTLASS_TEST_LEVEL_ACTIVE(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \ + TEST(NAME_STATIC,L##LEVEL##_##NAME_DYNAMIC) __VA_ARGS__ + +// disabled test macro +#define CUTLASS_TEST_LEVEL_DISABLED(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \ + TEST(NAME_STATIC,DISABLED_L##LEVEL##_##NAME_DYNAMIC) {} + +#if CUTLASS_TEST_LEVEL == 0 +#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) 
CUTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#elif CUTLASS_TEST_LEVEL == 1 +#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#else +#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#endif + +#if !defined(CUTLASS_TEST_UNIT_ENABLE_WARNINGS) +#define CUTLASS_TEST_UNIT_ENABLE_WARNINGS false +#endif + +#if (__CUDACC_VER_MAJOR__ >= 12) + #define CUDA_12_0_SM90_FEATURES_SUPPORTED true +#else + #define CUDA_12_0_SM90_FEATURES_SUPPORTED false +#endif + +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h new file mode 100644 index 0000000000000000000000000000000000000000..3035e9862bcb79b749b4cbc4a74341bceac9c598 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h @@ -0,0 +1,907 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Helpers to construct cache keys (operation, problem, types, and tensor hashes) and to store and look up cached unit test results +*/ +#pragma once + +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "thrust/universal_vector.h" + +#ifndef CUTLASS_TEST_ENABLE_CACHED_RESULTS +#define CUTLASS_TEST_ENABLE_CACHED_RESULTS false +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test::conv::device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Key identifying a cached test result +struct CachedTestKey { + + std::string op; ///< Concatenated string representation of operation performed + std::string problem; ///< Concatenated string representation of problem description + std::string types; ///< Concatenated string representation of operand types + uint32_t A; ///< Hashed result of tensor A + uint32_t B; ///< Hashed result of tensor B + uint32_t C; ///< Hashed result of tensor C + + // + // Methods + // + inline CachedTestKey(): A(), B(), C() { } + + inline CachedTestKey( + std::string op, ///< Concatenated string representation of operation performed + std::string problem, ///< Concatenated string representation of problem description + std::string types, ///< Concatenated string representation of operand types + uint32_t A, ///< Hashed result of tensor A + uint32_t B, ///< Hashed result of tensor B + uint32_t C ///< Hashed result of tensor C + ): + op(op), problem(problem), types(types), A(A), B(B), C(C) + { } + + /// Checks for equality of the problem + bool operator==(CachedTestKey const &rhs) const { + return op == rhs.op && problem == rhs.problem && types == rhs.types && A == rhs.A && B == rhs.B && C == rhs.C; + } +}; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +inline std::istream &operator>>(std::istream &in, CachedTestKey &result) { + + in >> result.op; + in >> result.problem; + in >> result.types; + in >> result.A; + in >> result.B; + in >> result.C; + + return in; +} + +inline std::ostream &operator<<(std::ostream &out, CachedTestKey const &result) { + + out << result.op << " "; + out << result.problem << " "; + out << result.types << " "; + out << result.A << " "; + out << result.B << " "; + out << result.C << " "; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct CachedTestResult { + uint32_t D; + // + // Methods + // + + CachedTestResult(): D() + { } + + CachedTestResult(uint32_t D): D(D) + { } + + operator bool() const { + return bool(D); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline std::istream &operator>>(std::istream &in, CachedTestResult &result) { + in >> result.D; + return in; +} + +inline std::ostream &operator<<(std::ostream &out, CachedTestResult const &result) { + out << result.D; + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct CachedTestResultListing { + + std::list> results; + + // + // Methods + // + + inline CachedTestResultListing(std::string const &path) { + std::ifstream file(path); + + while (file.good()) { + CachedTestKey key; + file >> key; + + CachedTestResult result; + file >> result; + + if (result) { + results.push_back(std::make_pair(key, result)); + } + } + } + + /// Returns the cached result + std::pair find(CachedTestKey const &rhs) const { + for (auto const & result : results) { + if (result.first == rhs) { + return std::make_pair(true, result.second); + } + } + return std::make_pair(false, CachedTestResult()); + } + + /// Appends an entry + void 
append(CachedTestKey const &key, CachedTestResult const &result) { + if (result) { + results.push_back(std::make_pair(key, result)); + } + } + + /// Writes the entire listing to a file + bool write(std::string const &path) { + std::ofstream file(path); + if (!file.good()) { + return false; + } + + for (auto const &result : results) { + file << result.first << result.second << std::endl; + } + + return true; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ScalarEncoder { + Element scalar; + + ScalarEncoder(Element s): scalar(s) { } + + std::string str() const { + std::stringstream ss; + Element s = scalar; + if (s < Element()) { + s = -s; + ss << "n"; + } + ss << s; + return ss.str(); + } +}; + +template +ScalarEncoder EncodeScalar(Element a) { + return ScalarEncoder(a); +} + +template +struct ScalarEncoder> { + cutlass::complex scalar; + + ScalarEncoder(cutlass::complex s): scalar(s) { } + + std::string str() const { + std::stringstream ss; + ss << EncodeScalar(scalar.real()) << "_" << EncodeScalar(scalar.imag()) << "i"; + return ss.str(); + } +}; + +template +std::ostream &operator<<(std::ostream &out, ScalarEncoder const &scalar) { + out << scalar.str(); + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline char const *EncodeOperator(cutlass::conv::Operator conv_op) { + switch (conv_op) { + case cutlass::conv::Operator::kFprop: return "fprop"; + case cutlass::conv::Operator::kDgrad: return "dgrad"; + case cutlass::conv::Operator::kWgrad: return "wgrad"; + case cutlass::conv::Operator::kDeconv: return "deconv"; + } + return "conv_unknown"; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Encode GemmCoord (Gemm problem size) +inline std::ostream &EncodeProblemSize( + std::ostream &out, + cutlass::gemm::GemmCoord const &problem) { + + out << problem.m() 
<< "x" << problem.n() << "x" << problem.k() << "_"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Encode Conv2dProblemSize +inline std::ostream &EncodeProblemSize( + std::ostream &out, + cutlass::conv::Conv2dProblemSize const &problem) { + + out << problem.N << "x" << problem.H << "x" << problem.W << "x" << problem.C << "_" + << problem.P << "x" << problem.Q << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_"; + + out << "pad_h" << problem.pad_h << "w" << problem.pad_w << "_"; + out << "stride_h" << problem.stride_h << "w" << problem.stride_w << "_"; + out << "dil_h" << problem.dilation_h << "w" << problem.dilation_w << "_"; + + switch (problem.mode) { + case cutlass::conv::Mode::kCrossCorrelation: + out << "corr"; + break; + case cutlass::conv::Mode::kConvolution: + out << "conv"; + break; + } + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Encode Conv3dProblemSize +inline std::ostream &EncodeProblemSize( + std::ostream &out, + cutlass::conv::Conv3dProblemSize const &problem) { + + out << problem.N << "x" << problem.D << "x" << problem.H << "x" << problem.W << "x" << problem.C << "_" + << problem.Z << problem.P << "x" << problem.Q << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_"; + + out << "pad_d" << problem.pad_h << "h" << problem.pad_h << "w" << problem.pad_w << "_"; + out << "stride_d" << problem.stride_d << "h" << problem.stride_h << "w" << problem.stride_w << "_"; + out << "dil_d" << problem.dilation_d << "h" << problem.dilation_h << "w" << problem.dilation_w << "_"; + + switch (problem.mode) { + case cutlass::conv::Mode::kCrossCorrelation: + out << "corr"; + break; + case cutlass::conv::Mode::kConvolution: + out << "conv"; + break; + } + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Encode 3.x 
ConvNd ProblemShape +template +inline std::ostream &EncodeProblemSize( + std::ostream &out, + ProblemShape const& problem_shape) { + + out << problem_shape.shape_A << "_"; + out << problem_shape.shape_B << "_"; + + out << "padl" << problem_shape.lower_padding << "_"; + out << "padu" << problem_shape.upper_padding << "_"; + out << "str" << problem_shape.traversal_stride << "_"; + out << "dil" << problem_shape.dilation << "_"; + + switch (problem_shape.mode) { + case cutlass::conv::Mode::kCrossCorrelation: + out << "corr"; + break; + case cutlass::conv::Mode::kConvolution: + out << "conv"; + break; + } + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline std::string ElementTypeName() { + return std::string(typeid(Element).name()); +} + +template <> +inline std::string ElementTypeName() { + return "h"; +} + +template <> +inline std::string ElementTypeName>() { + return "ch"; +} + +template <> +inline std::string ElementTypeName() { + return "bf16"; +} + +template <> +inline std::string ElementTypeName>() { + return "cbf16"; +} + +template <> +inline std::string ElementTypeName() { + return "tf32"; +} + +template <> +inline std::string ElementTypeName>() { + return "ctf32"; +} + +template <> +inline std::string ElementTypeName>() { + return "c"; +} + +template <> +inline std::string ElementTypeName>() { + return "z"; +} + +template <> +inline std::string ElementTypeName>() { + return "q"; +} + +template <> +inline std::string ElementTypeName() { + return "s8"; +} + +template <> +inline std::string ElementTypeName() { + return "u8"; +} + +template <> +inline std::string ElementTypeName() { + return "s4"; +} + +template <> +inline std::string ElementTypeName() { + return "u4"; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline std::string LayoutTypeName() { + return std::string(typeid(Layout).name()); +} + 
+template <> +inline std::string LayoutTypeName() { + return "n"; +} + +template <> +inline std::string LayoutTypeName() { + return "t"; +} + +template <> +inline std::string LayoutTypeName() { + return "nhwc"; +} + +template <> +inline std::string LayoutTypeName>() { + return "nc32hw32"; +} + +template <> +inline std::string LayoutTypeName>() { + return "nc64hw64"; +} + +template <> +inline std::string LayoutTypeName>() { + return "c32rsk32"; +} + +template <> +inline std::string LayoutTypeName>() { + return "c64rsk64"; +} + +template <> +inline std::string LayoutTypeName() { + return "ndhwc"; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline std::string TensorTypeName() { + std::stringstream ss; + ss << ElementTypeName() << LayoutTypeName(); + return ss.str(); +} + +template +inline std::string TensorTypeName() { + std::stringstream ss; + ss << ElementTypeName(); + return ss.str(); +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function on a byte array +struct CRC32 { + + uint32_t table[256]; + + // + // Methods + // + + CRC32() { + + uint32_t rem; + int i, j; + + for (i = 0; i < 256; i++) { + rem = i; + for (j = 0; j < 8; j++) { + if (rem & 1) { + rem >>= 1; + rem ^= 0xedb88320; + } else + rem >>= 1; + } + table[i] = rem; + } + } + + /// Computes the CRC of an array of bytes + uint32_t operator()(void const *start, size_t length, uint32_t crc = uint32_t()) const { + uint8_t const *p = static_cast(start); + uint8_t const *q = static_cast(start) + length; + + crc = ~crc; + + for (; p != q; ++p) { + uint8_t octet = *p; + crc = (crc >> 8) ^ table[(crc & 0xff) ^ octet]; + } + + return ~crc; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Element, typename Layout +> +uint32_t TensorHash( + cutlass::TensorView view, + CRC32 const &hash = CRC32(), 
+ uint32_t crc = uint32_t() +) { + + return hash(view.data(), view.capacity() * cutlass::sizeof_bits::value / 8, crc); +} + +template +uint32_t TensorHash( + thrust::universal_vector& tensor, + CRC32 const &hash = CRC32(), + uint32_t crc = uint32_t() +) { + + return hash(tensor.data().get(), tensor.size() * cutlass::sizeof_bits::value / 8, crc); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline std::ostream &EncodeTypes( + std::ostream &out +) { + + out << TensorTypeName() << "_" + << TensorTypeName() << "_" + << TensorTypeName() << "_" + << ElementTypeName() << "_" + << ElementTypeName(); + + return out; +} + +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementD +> +inline std::ostream &EncodeTypes( + std::ostream &out +) { + + out << TensorTypeName() << "_" + << TensorTypeName() << "_" + << TensorTypeName() << "_" + << ElementTypeName(); + + return out; +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedGemmTestKey( + cutlass::gemm::GemmCoord const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode gemm operator and problem sizes + key.op = "gemm"; + + std::stringstream ss_problem; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + key.problem = ss_problem.str(); + + // Encode problem data types + 
std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode hash for problem data + CRC32 crc_hash; + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedConv2dTestKey( + + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode conv2d operator and problem sizes + key.op = "conv2d"; + + std::stringstream ss_problem; + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode hash for problem data + CRC32 crc_hash; + + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey 
CreateCachedConv2dWithBroadcastTestKey( + + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode conv2d operator and problem sizes + key.op = "conv2d_with_broadcast"; + + std::stringstream ss_problem; + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode hash for problem data + CRC32 crc_hash; + + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedConv2dWithReductionTestKey( + + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode conv2d operator and problem sizes + key.op = "conv2d_with_reduction"; + + std::stringstream ss_problem; + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream 
ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode hash for problem data + CRC32 crc_hash; + + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedConv3dTestKey( + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv3dProblemSize const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode conv3d operator and problem sizes + key.op = "conv3d"; + + std::stringstream ss_problem; + + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode problem data + CRC32 crc_hash; + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape, + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementD +> +inline CachedTestKey CreateCachedConvNd3xTestKey( + cutlass::conv::Operator conv_operator, + ProblemShape const& problem_shape, + 
double alpha, + double beta, + thrust::universal_vector A, + thrust::universal_vector B, + thrust::universal_vector C +) { + + CachedTestKey key; + + // Encode convNd operator and problem sizes + std::stringstream ss_op; + ss_op << "conv" << ProblemShape::RankS << "d"; + key.op = ss_op.str(); + + std::stringstream ss_problem; + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem_shape); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, + ElementB, + ElementC, + ElementD>(ss_types); + key.types = ss_types.str(); + + // Encode problem data + CRC32 crc_hash; + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace test::conv::device + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h new file mode 100644 index 0000000000000000000000000000000000000000..a14134b2854732e669977831207a456d28beed9f --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h @@ -0,0 +1,927 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file
+    \brief Implicit GEMM testbed sizes for Conv2d problem
+*/
+#pragma once
+
+#include <vector>
+
+#include "cutlass/cutlass.h"
+#include "cutlass/layout/matrix.h"
+#include "cutlass/conv/convolution.h"
+#include "cutlass/conv/conv2d_problem_size.h"
+
+namespace test {
+namespace conv {
+namespace device {
+
+/// Ordered list of conv2d problem sizes handed to the testbed.
+using Conv2dProblemVector = std::vector<cutlass::conv::Conv2dProblemSize>;
+
+//
+// Structures to prune items from Conv2dProblemVector
+//
+
+/// Specification template for pruning items for convolution problem lists.
+///
+/// A specification is a predicate object: is_satisfied() returns true for the
+/// items that should be kept. T is the item type the predicate inspects.
+template <typename T> struct Specification
+{
+  virtual ~Specification() = default;
+
+  /// Returns true when `item` matches this specification.
+  virtual bool is_satisfied(T item) const = 0;
+};
+
+/// Input size (NHWC) specification: matches problems whose input tensor extent
+/// equals `input_size` in all four dimensions.
+struct InputSizeSpecification : Specification<cutlass::conv::Conv2dProblemSize>
+{
+  cutlass::Tensor4DCoord input_size;   // required (n, h, w, c) extent
+
+  InputSizeSpecification(cutlass::Tensor4DCoord input_size_) : input_size(input_size_) {}
+
+  /// True when item.N, item.H, item.W and item.C all equal the stored extent.
+  bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override {
+    return ((input_size.n() == item.N) && (input_size.h() == item.H) && (input_size.w() == item.W) && (input_size.c() == item.C));
+  }
+};
+
+/// Stride (stride_h, stride_w) specification: matches problems whose vertical
+/// and horizontal strides equal `stride.row()` and `stride.column()`.
+struct StrideSpecification : Specification<cutlass::conv::Conv2dProblemSize>
+{
+  cutlass::MatrixCoord stride;         // (row, column) == (stride_h, stride_w)
+
+  StrideSpecification(cutlass::MatrixCoord stride_) : stride(stride_) {}
+
+  /// True when both strides of `item` equal the stored pair.
+  bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override {
+    // The column component must be checked against stride_w. Comparing it with
+    // stride_h a second time would accept or reject asymmetric strides wrongly.
+    return ((stride.row() == item.stride_h) && (stride.column() == item.stride_w));
+  }
+};
+
+/// Channel (C,K) specification, must be multiple of minimum channel.
+/// `channel_multiple` is used as a modulo divisor, so it must be non-zero.
+struct ChannelDivisibilitySpecification : Specification<cutlass::conv::Conv2dProblemSize>
+{
+  int channel_multiple;                // divisor both K and C must be a multiple of
+
+  ChannelDivisibilitySpecification(int channel_multiple_) : channel_multiple(channel_multiple_) {}
+
+  /// True when item.K and item.C are both exact multiples of channel_multiple.
+  bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override {
+    return ((item.K % channel_multiple == 0) && (item.C % channel_multiple == 0));
+  }
+};
+
+//
+// Pruning function for items from Conv2dProblemVector based on a Specification
+//
+inline Conv2dProblemVector prune(Conv2dProblemVector const &items, + 
Specification const &spec) +{ + Conv2dProblemVector pruned_list; + + for (auto& p : items) + if (spec.is_satisfied(p)) + pruned_list.push_back(p); + return pruned_list; +} + + +//////////////////////////////////////////////////////////////////////////// +/// Structure TestbedConv2dProblemSizes initializes and holds conv default and +/// important network sizes +//////////////////////////////////////////////////////////////////////////// +struct TestbedConv2dProblemSizes { + + // + // Data members + // + int minimum_channel_size; + + Conv2dProblemVector conv2d_default_sizes; + Conv2dProblemVector conv2d_rigorous_sizes; + Conv2dProblemVector conv2d_resnet50_sizes; + Conv2dProblemVector conv2d_resnet50_sizes_perf; + + // + // Methods + // + /// Default ctor + TestbedConv2dProblemSizes(int minimum_channel_size_ = 64): minimum_channel_size (minimum_channel_size_) { + initialize_conv2d_default_sizes(); + initialize_conv2d_rigorous_sizes(); + initialize_conv2d_resnet50_sizes(conv2d_resnet50_sizes, 1 /*batch-size*/); + + initialize_conv2d_resnet50_sizes(conv2d_resnet50_sizes_perf, 34 /*batch-size*/); + filter_all(); + } + + /// Eliminates some illegal cases + void filter_all() { + + Conv2dProblemVector *problems_vectors[] = { + &conv2d_default_sizes, + &conv2d_rigorous_sizes, + &conv2d_resnet50_sizes, + &conv2d_resnet50_sizes_perf + }; + + for (Conv2dProblemVector *problems : problems_vectors) { + Conv2dProblemVector filtered; + + for (cutlass::conv::Conv2dProblemSize const & problem : *problems) { + if (!(problem.C % minimum_channel_size)) { + filtered.push_back(problem); + } + } + + *problems = filtered; + } + } + + // Add a few standard convolution problem sizes + void initialize_conv2d_default_sizes() { + + //////////////////////////////////////////////////////////////////////////////////////////// + // Small input size x stride (1,1) + // C < CTA::K and non-multiples of CTA::K. 
Typical CTA::K = {32, 64} + //////////////////////////////////////////////////////////////////////////////////////////// + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 1, 1, minimum_channel_size}, // input size (NHWC) + {8, 1, 1, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 1, 8, minimum_channel_size}, // input size (NHWC) + {8, 1, 3, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 7, 8, minimum_channel_size}, // input size (NHWC) + {8, 3, 3, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 4, 4, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {2, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 5, 5, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 6, 5, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, 
stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 6, 6, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 7, 7, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////////////// + // Small input size x stride (1,1) asymmetric paddings (1, 0, 1, 0) + // C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} + //////////////////////////////////////////////////////////////////////////////////////////// + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 1, 1, minimum_channel_size}, // input size (NHWC) + {8, 1, 1, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 1, 8, minimum_channel_size}, // input size (NHWC) + {8, 1, 3, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 7, 8, minimum_channel_size}, // input size (NHWC) + {8, 3, 3, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, 
dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 4, 4, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {2, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 5, 5, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 6, 5, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 6, 6, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 7, 7, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////////////// + // Small input size x stride (2,2) + // C < CTA::K and non-multiples of CTA::K. 
Typical CTA::K = {32, 64} + //////////////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 11, 7, minimum_channel_size}, // input size (NHWC) + {8, 1, 1, minimum_channel_size}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 11, 7, minimum_channel_size}, // input size (NHWC) + {8, 3, 3, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 13, 11, minimum_channel_size}, // input size (NHWC) + {8, 1, 1, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 17, 19, minimum_channel_size}, // input size (NHWC) + {16, 2, 2, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 23, 5, minimum_channel_size}, // input size (NHWC) + {16, 3, 3, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 13, 17, 8}, // input size (NHWC) + {24, 3, 3, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation 
(dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 23, 21, 8}, // input size (NHWC) + {24, 3, 3, 8}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 20, 24, 8}, // input size (NHWC) + {40, 3, 3, 8}, // filter size (KRSC) + {3, 3, 3, 3}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size (1x16x16x128), filter size (1x1, 2x2, 3x3, 5x5), stride (1, 1) + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 15, 19, 160}, // input size (NHWC) + {224, 1, 1, 160}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 19, 37, 160}, // input size (NHWC) + {224, 3, 3, 160}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 16, 16, 160}, // input size (NHWC) + {224, 2, 3, 160}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 23, 21, 128}, // input size (NHWC) + {224, 3, 3, 128}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // 
dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 29, 37, 160}, // input size (NHWC) + {224, 5, 5, 160}, // filter size (KRSC) + {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // C > CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 15, 19, 32 + minimum_channel_size}, // input size (NHWC) + {96, 3, 3, 32 + minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 16, 24, 64 + minimum_channel_size}, // input size (NHWC) + {96, 3, 3, 64 + minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size, filter size (1x1, 3,x3, 5x5, 7x7), stride (2, 2) + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 13, 16, 288}, // input size (NHWC) + {160, 5, 5, 288}, // filter size (KRSC) + {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 55, 51, 256}, // input size (NHWC) + {512, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, 
// stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 71, 80, 32}, // input size (NHWC) + {64, 5, 5, 32}, // filter size (KRSC) + {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 224, 224, 8}, // input size (NHWC) + {64, 7, 7, 8}, // filter size (KRSC) + {3, 3, 3, 3}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size stride (3, 3), filter (3, 3), non-default padding + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 23, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size padding > stride, asymmetric filter, padding and striding + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 31, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {5, 5, 7, 7}, // padding (pad_h, _, pad_w, _) + {3, 4}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 35, 256}, // input size (NHWC) + {512, 7, 5, 256}, // filter size (KRSC) + {11, 11, 7, 7}, // padding (pad_h, _, pad_w, _) + {3, 5}, // stride (stride_h, stride_w) + {1, 
1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size *mixed* stride (1, 2) and (2, 1), + // filter (3, 3), default padding + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 27, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 27, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + ///////////////////////////////////////////////////////////////////////////// + // Additional input size + ///////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 28, 28, 256}, // input size (NHWC) + {256, 2, 2, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 32, 32, 16}, // input size (NHWC) + {32, 3, 3, 16}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {6, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {32, 24, 32, 32}, // input size (NHWC) + {32, 1, 2, 32}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + 
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {4, 4, 5, 128}, // input size (NHWC) + {256, 3, 6, 128}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + {4, 3, 3, 256} // output size (NPQK) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {4, 2, 3, 256}, // input size (NHWC) + {328, 3, 5, 256}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + {4, 1, 1, 328} // output size (NPQK) + )); + } + + + // Add a few large and rigorous convolution problem sizes + void initialize_conv2d_rigorous_sizes() { + +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 124, 224, 96}, // input size (NHWC) + {24, 7, 7, 96}, // filter size (KRSC) + {1, 229, 129, 32} // output size (NPQK) + )); + + conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 233, 35, 48}, // input size (NHWC) + {24, 7, 5, 48}, // filter size (KRSC) + {1, 233, 35, 24} // output size (NPQK) + )); + +#endif + + } + + + // Add resent50 layers to unit testing sizes + void initialize_conv2d_resnet50_sizes(Conv2dProblemVector &conv2d_problem_vector, int batch_size = 1){ + +#if 0 // Resnet50 first layer (layer_id = 0) with channel = 3 is not supported in cutlass + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + [1, 224, 224, 3], // input size (NHWC) + [64, 7, 7, 3], // filter size (KRSC) + [3, 3, 3, 3], // padding (pad_h, _, pad_w, _) + [2, 2], // stride (stride_h, stride_w) + [1, 1], // dilation (dilation_h, dilation_w) + )); +#endif + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 64}, // input size (NHWC) + {256, 1, 1, 64}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, 
// stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 64}, // input size (NHWC) + {64, 1, 1, 64}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 64}, // input size (NHWC) + {64, 3, 3, 64}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 256}, // input size (NHWC) + {64, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 256}, // input size (NHWC) + {512, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 256}, // input size (NHWC) + {128, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 128}, // input size (NHWC) + {128, 3, 3, 128}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 128}, // 
input size (NHWC) + {512, 1, 1, 128}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 512}, // input size (NHWC) + {128, 1, 1, 512}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 512}, // input size (NHWC) + {1024, 1, 1, 512}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 512}, // input size (NHWC) + {256, 1, 1, 512}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 256}, // input size (NHWC) + {256, 3, 3, 256}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 256}, // input size (NHWC) + {1024, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 1024}, // input size (NHWC) + {256, 1, 1, 1024}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation 
(dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 1024}, // input size (NHWC) + {2048, 1, 1, 1024}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 1024}, // input size (NHWC) + {512, 1, 1, 1024}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 7, 7, 512}, // input size (NHWC) + {512, 3, 3, 512}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 7, 7, 512}, // input size (NHWC) + {2048, 1, 1, 512}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 7, 7, 2048}, // input size (NHWC) + {512, 1, 1, 2048}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + } + +}; + + +//////////////////////////////////////////////////////////////////////////// +/// Structure TestbedGroupConv2dProblemSizes initializes and holds group conv default and +/// important network sizes +//////////////////////////////////////////////////////////////////////////// +struct TestbedGroupConv2dProblemSizes { + + // + // Data members + // + int threadblock_n; + int threadblock_k; + int minimum_channel_size; + + 
Conv2dProblemVector default_single_group_sizes; + Conv2dProblemVector default_multiple_group_sizes; + + // + // Methods + // + /// Default ctor + TestbedGroupConv2dProblemSizes( + int threadblock_n_, + int threadblock_k_, + int minimum_channel_size_ = 64) + : threadblock_n (threadblock_n_), + threadblock_k (threadblock_k_), + minimum_channel_size (minimum_channel_size_) { + initialize_group_conv2d_default_sizes(); + filter_all(); + } + + /// Eliminates some illegal cases + void filter_all() { + + Conv2dProblemVector *problems_vectors[] = { + &default_single_group_sizes, + &default_multiple_group_sizes + }; + + for (Conv2dProblemVector *problems : problems_vectors) { + Conv2dProblemVector filtered; + + for (cutlass::conv::Conv2dProblemSize const & problem : *problems) { + if (!((problem.C / problem.groups) % minimum_channel_size)) { + filtered.push_back(problem); + } + } + + *problems = filtered; + } + } + + // Add a few standard convolution problem sizes + void initialize_group_conv2d_default_sizes() { + + //////////////////////////////////////////////////////////////////////////////////// + // One group calculated by one or multiple CTAs: k_per_group % CTA::N = 0 + // One CTA calculates a single group + //////////////////////////////////////////////////////////////////////////////////// + + for (int cta_per_group_k = 1; cta_per_group_k < 4; ++cta_per_group_k) { + // groups = 2, 3, 4 + for (int groups = 2; groups < 5; ++groups) { + + int conv_k = cta_per_group_k * threadblock_n * groups; + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 2 * groups}, // input size (NHWC) + {conv_k, 3, 3, threadblock_k * 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + groups // groups + )); + + } // loop groups + } // loop cta_per_group_k + + // Partial 
gemm_k: k_per_group == CTA::N && channels_per_group < CTA::K + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k}, // input size (NHWC) + {threadblock_n * 2, 3, 3, threadblock_k / 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 2 // groups + )); + + // Larger problem sizes + + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 696}, // input size (NHWC) + {768, 3, 3, 232}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 3 // groups + )); + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 14, 14, 1392}, // input size (NHWC) + {1536, 3, 3, 232}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 3 // groups + )); + + //////////////////////////////////////////////////////////////////////////////////// + // One CTA calculate multiple groups: CTA::N % k_per_group = 0 + //////////////////////////////////////////////////////////////////////////////////// + + // 2 groups per CTA + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 4}, // input size (NHWC) + {threadblock_n, 3, 3, threadblock_k * 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 2 // groups + )); + + // 2 groups per CTA and partial gemm_k + 
default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k}, // input size (NHWC) + {threadblock_n, 3, 3, threadblock_k / 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 2 // groups + )); + + // 4 groups per CTA + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 8}, // input size (NHWC) + {threadblock_n / 2, 3, 3, threadblock_k * 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 4 // groups + )); + + // 4 groups per CTA and partial gemm_k + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 2}, // input size (NHWC) + {threadblock_n / 2, 3, 3, threadblock_k / 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 4 // groups + )); + } + +}; + + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..34588ecb467b824cc0fcbbff0bc0d99e4385d80e --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h @@ -0,0 +1,818 @@ +/*************************************************************************************************** + * Copyright (c) 2017 
- 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Implicit GEMM testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv2d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedConv2d { +public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + + /// Reduction kernel + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOutputOp::ElementAccumulator, + EpilogueOutputOp::kCount + >; + + using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, + EpilogueOutputOp, + ReductionOp + >; + + using ReductionDevice = cutlass::reduction::device::ReduceSplitK; + using ReductionStrideIndex = typename ReductionDevice::StrideIndex; + +public: + + /// Initialization + 
cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + int tested_problem_count; + +public: + + TestbedConv2d( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope = 3; + } + else { + scope = 5; + } + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + 
tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + // increment tested problem count run by the testbed + tested_problem_count++; + +#if 0 // display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + split_k_mode + ); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // conv2d operation with parallel split-k-mode + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // conv2d output is written to workspace in global memory + conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); + // accumulate mma for each cta in k-dimension (1.0 * A * B) + conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; + // update conv2d operator arguments + status = conv2d_op.update(conv2d_args, workspace.get()); + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run." 
<< std::endl; + return false; + } + + + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // configure parallel reduction operator + ReductionDevice reduction_op; + + typename ReductionDevice::Arguments reduction_args( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), + problem_size.split_k_slices, + cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), + { + reinterpret_cast (workspace.get()), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_D_computed.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_C.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C + {alpha, beta} + ); + + status = reduction_op.initialize(reduction_args, nullptr); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run prallel reduction kernel + status = reduction_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + } + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_D_computed.sync_host(); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = CreateCachedConv2dTestKey< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + alpha, + beta, + tensor_A.host_view(), + tensor_B.host_view(), + tensor_C.host_view() + ); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + 
std::string conv2d_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + } + + if (!cached_result_loaded) { + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + +#endif + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + cached_test_result.D = TensorHash(tensor_D_reference.host_view()); + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(conv2d_result_cache_name); + } + } // if (!cached_result_loaded) + + uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + passed = (tensor_D_hash == cached_test_result.D); + + EXPECT_EQ(tensor_D_hash, cached_test_result.D) + << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; + } + else { + + passed = cutlass::reference::host::TensorEquals( + 
tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + } + + EXPECT_TRUE(passed); + + std::stringstream ss_problem_size_text; + ss_problem_size_text << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_"); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? 
"deconv_" : "wgrad_"))) + << ss_problem_size_text.str() + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n"; + + results << "\nD reference (hash: " << cached_test_result.D << ")\n"; + + if (!cached_result_loaded) { + results + << tensor_D_reference.host_view() << "\n"; + } + + results + << "\nD computed (hash: " << tensor_D_hash << ")\n" + << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestSpecificConv2d( + const Conv2dProblemVector & problem_sizes) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2d testbed; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // + // Test + // + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from 
test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestAllConv2d( + const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), + const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2d testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + // Vectors of Conv2dProblemVector (lenient/easiest to rigorous problem sizes) + std::vector problem_vectors = { + conv_test_sizes, // run user specified sizes + conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + //conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + // Flatten 2D problem_vectors into a 1D problem_sizes + std::vector problem_sizes; + for (auto problem_vector : problem_vectors) { + for(auto conv_problem : problem_vector) { + problem_sizes.push_back(conv_problem); + } + } + + // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reverse the order (rigorous to lenient) + // run the most rigorous problem size first + if (CutlassUnitTestProblemCount()) { + std::reverse(problem_sizes.begin(), problem_sizes.end()); + } + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != 
conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + + // Fixed channels algorithm requires channel count to match access size + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == + cutlass::conv::IteratorAlgorithm::kFixedChannels) { + if (conv_problem.C != ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) { + continue; + } + } + + // Few channels algorithm requires channel count to match access size + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == + cutlass::conv::IteratorAlgorithm::kFewChannels) { + if (conv_problem.C % ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) { + continue; + } + } + + // CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w} + // Although strided dgrad works for all stride combinations, we are only going + // to run strided dgrad for non-unity strides + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + 
conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the number of tested problem counts + if (CutlassUnitTestProblemCount() && + testbed.tested_problem_count > CutlassUnitTestProblemCount()) { + return true; + } + } + + // Small-channels convolution can't run here. + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == + cutlass::conv::IteratorAlgorithm::kFixedChannels) { + + return true; + } + + // Small-channels convolution can't run here. + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == + cutlass::conv::IteratorAlgorithm::kFewChannels) { + + return true; + } + + // CUTLASS DGRAD's *strided* specialization does not support split-k mode + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}), // dilation (dilation_h, dilation_w) + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}) // dilation (dilation_h, dilation_w) 
+ .reset_split_k_slices(2), + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + if (!passed) { + return false; + } + + return passed; + } + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( + {1, 17, 11, 288}, // input size (NHWC) + {160, 3, 3, 288}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial, + cutlass::conv::SplitKMode::kParallel, + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + + // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the number of tested problem counts + if (CutlassUnitTestProblemCount() && + testbed.tested_problem_count > CutlassUnitTestProblemCount()) { + return true; + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test + 
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h new file mode 100644 index 0000000000000000000000000000000000000000..cf075674da673cf8e056172732f912b8acba3c5b --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h @@ -0,0 +1,666 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv2d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/host_reorder.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +template +class InterleavedTestbedConv2d { +public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using 
EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + + /// Reduction kernel + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOutputOp::ElementAccumulator, + EpilogueOutputOp::kCount + >; + + using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, + EpilogueOutputOp, + ReductionOp + >; + + using ReductionDevice = cutlass::reduction::device::ReduceSplitK; + using ReductionStrideIndex = typename ReductionDevice::StrideIndex; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_B_reordered; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + +public: + + InterleavedTestbedConv2d( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + scope = 3; + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if 
(dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_B_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + cutlass::reorder_convK( + tensor_B_reordered.host_ref(), tensor_B.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size)); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_B_reordered.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw 
std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerMultiprocessor < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args( + problem_size, + tensor_A.device_ref(), + tensor_B_reordered.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + split_k_mode + ); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); + + // conv2d operation with parallel split-k-mode + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // conv2d output is written to workspace in global memory + conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); + // accumulate mma for each cta in k-dimension (1.0 * A * B) + conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; + // update conv2d operator arguments + status = conv2d_op.update(conv2d_args, workspace.get()); 
+ } + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // configure parallel reduction operator + ReductionDevice reduction_op; + + typename ReductionDevice::Arguments reduction_args( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), + problem_size.split_k_slices, + cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), + { + reinterpret_cast (workspace.get()), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_D_computed.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_C.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C + {alpha, beta} + ); + + status = reduction_op.initialize(reduction_args, nullptr); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run prallel reduction kernel + status = reduction_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + } + bool passed = false; + + tensor_D_computed.sync_host(); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = CreateCachedConv2dTestKey< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + alpha, + beta, + tensor_A.host_view(), + tensor_B.host_view(), + tensor_C.host_view() + ); + + // + // Look 
for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string conv2d_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + } + + if (!cached_result_loaded) { + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + cutlass::NumericConverterClamp + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ElementC, + cutlass::NumericConverterClamp + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + +#endif + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + cached_test_result.D = TensorHash(tensor_D_reference.host_view()); + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(conv2d_result_cache_name); + } + } // if (!cached_result_loaded) + + uint32_t tensor_D_hash = 
TensorHash(tensor_D_computed.host_view()); + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + passed = (tensor_D_hash == cached_test_result.D); + + EXPECT_EQ(tensor_D_hash, cached_test_result.D) + << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; + } + else { + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + } + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + << "ncxhwx_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_cxrskx_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? 
"xcorr_" : "conv_") + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n"; + + results << "\nD reference (hash: " << cached_test_result.D << ")\n"; + + if (!cached_result_loaded) { + results + << tensor_D_reference.host_view() << "\n"; + } + + results + << "\nD computed (hash: " << tensor_D_hash << ")\n" + << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestAllInterleavedConv2d( + const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), + const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + InterleavedTestbedConv2d testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(InterleavedK); // minimum channel size must be multiple of InterleavedK for interleaved layout + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; 
+ + Conv2dProblemVector const *problem_vectors[] = { + &conv_test_sizes, // run user specified sizes + &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv2dProblemVector const * problem_vector : problem_vectors) { + + ChannelDivisibilitySpecification channel_spec(InterleavedK); //input and output channels must be multiple of InterleavedK + auto pruned_problem_vector = prune(*problem_vector, channel_spec); + + // Run conv testbed on default convolution sizes + for(auto conv_problem : pruned_problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's unity stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + 
cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + } + +#if 0 + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( + {1, 17, 11, 288}, // input size (NHWC) + {160, 3, 3, 288}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial, + cutlass::conv::SplitKMode::kParallel, + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } +#endif + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..ad7b2ce61a66a79f852c0aac0895d10ba18e5466 --- 
/dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h @@ -0,0 +1,622 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/*! \file + \brief Testbed for running device-level Conv2Ds with absolute maximum calculation and scaling +*/ + +#pragma once + +#include +#include +#include + +#include "conv2d_problems.h" +#include "../../common/cutlass_unit_test.h" +#include "../../gemm/device/testbed_utils.h" + +#include "cutlass/matrix_coord.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_reduce.h" + +namespace test { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Conv, + template class ActivationFunctor +> +struct TestbedConv2dWithAbsMax { + + using ElementAccumulator = typename Conv::ElementAccumulator; + using ElementCompute = typename Conv::UnderlyingKernel::Epilogue::OutputOp::ElementCompute; + using ElementScalingFactor = typename Conv::EpilogueOutputOp::ElementScalingFactor; + using ElementAbsmax = typename Conv::EpilogueOutputOp::ElementAbsmax; + static cutlass::conv::Operator const kConvolutionalOperator = Conv::kConvolutionalOperator; + + static bool const kScaleAux = Conv::EpilogueOutputOp::kIsScalingAndAmaxAuxOutputNeeded; + static bool const kScaleOutput = Conv::EpilogueOutputOp::kIsScalingAndAmaxOutputNeeded; + bool doScaleA; + bool doScaleB; + bool doScaleC; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + 
cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_Aux; + cutlass::HostTensor tensor_D; + cutlass::HostTensor tensor_Vector; + cutlass::HostTensor tmp_D; + cutlass::HostTensor reference_D; + cutlass::HostTensor reference_Aux; + cutlass::HostTensor scale_A; + cutlass::HostTensor scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; + + // + // Methods + // + + TestbedConv2dWithAbsMax( + bool scaleA = true, + bool scaleB = true, + bool scaleC = true, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + doScaleA(scaleA), doScaleB(scaleB), doScaleC(scaleC), + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize scaling factors + template + bool initialize_scale_factor(cutlass::TensorView view, uint64_t seed, int bits=0) { + cutlass::reference::host::TensorFillRandomUniform(view, seed, double(1.), double(0.), bits); + return true; + } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 
scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::conv::Conv2dProblemSize const &problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Vector.resize({1, 1, 1, implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c()}); + reference_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false); + tmp_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + EXPECT_TRUE(initialize_tensor(tensor_Vector.host_view(), init_C, seed + 2020)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ cutlass::Coord<4> origin(0); + tensor_A.host_view().at(origin) = typename Conv::ElementA(1); + tensor_B.host_view().at(origin) = typename Conv::ElementB(1); + tensor_C.host_view().at(origin) = typename Conv::ElementC(1); + tensor_Vector.host_view().at(origin) = typename Conv::ElementC(1); + + cutlass::reference::host::TensorFill(tensor_D.host_view()); + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + tensor_Vector.sync_device(); + + int scale_bits = 2; + if (doScaleA) { + scale_A.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_A.host_view(), seed + 2021, scale_bits)); + scale_A.sync_device(); + } + + if (doScaleB) { + scale_B.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_B.host_view(), seed + 2022, scale_bits)); + scale_B.sync_device(); + } + + if (doScaleC) { + scale_C.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_C.host_view(), seed + 2023, scale_bits)); + scale_C.sync_device(); + } + + if (kScaleOutput) { + scale_D.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_D.host_view(), seed + 2024, scale_bits)); + scale_D.sync_device(); + + abs_max_D.resize({1, 1, 1, 1}); + cutlass::reference::host::TensorFill(abs_max_D.host_view()); + abs_max_D.sync_device(); + + reference_abs_max_D.resize({1, 1, 1, 1}); + } + + if (kScaleAux) { + tensor_Aux.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + cutlass::reference::host::TensorFill(tensor_Aux.host_view()); + tensor_Aux.sync_device(); + + scale_Aux.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_Aux.host_view(), seed + 2025, scale_bits)); + scale_Aux.sync_device(); + + abs_max_Aux.resize({1, 1, 1, 1}); + cutlass::reference::host::TensorFill(abs_max_Aux.host_view()); + abs_max_Aux.sync_device(); + + 
reference_Aux.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false); + reference_abs_max_Aux.resize({1, 1, 1, 1}); + } + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::conv::Conv2dProblemSize const &problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + if (kScaleAux) { + tensor_Aux.sync_host(); + abs_max_Aux.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Aux.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_Aux.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_Aux.host_view()), 0); + passed &= cutlass::reference::host::TensorEquals(reference_Aux.host_view(), tensor_Aux.host_view()); + passed &= cutlass::reference::host::TensorEquals(abs_max_Aux.host_view(), reference_abs_max_Aux.host_view()); + } + + if (kScaleOutput) { + abs_max_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_D.host_view()), 0); + passed &= cutlass::reference::host::TensorEquals(abs_max_D.host_view(), reference_abs_max_D.host_view()); + } + + EXPECT_TRUE(passed) << " mismatched reference"; + + if (!passed) { + + std::ofstream file0("conv_testbed_with_amax_errors_reference.txt"); + std::ofstream file1("conv_testbed_with_amax_errors_computed.txt"); + + std::ofstream file("conv_testbed_with_amax_errors.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << 
alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nVector =\n" << tensor_Vector.host_view() + << "\nScaleA = " << scale_A.host_view() + << "\nScaleB = " << scale_B.host_view() + << "\nScaleC = " << scale_C.host_view() + << "\nScaleD = " << scale_D.host_view() + << "\nScaleAux = " << scale_Aux.host_view() + << std::endl; + + file0 << "\n\nReference D =\n" << reference_D.host_view() << std::endl; + file1 << "\n\nComputed D =\n" << tensor_D.host_view() << std::endl; + if (kScaleAux) { + file0 << "\n\nReference Aux =\n" << reference_Aux.host_view() << std::endl; + file1 << "\n\nComputed Aux =\n" << tensor_Aux.host_view() << std::endl; + file0 << "\n\nReference Absmax Aux = " << reference_abs_max_Aux.host_view() << std::endl; + file1 << "\n\nComputed Absmax Aux = " << abs_max_Aux.host_view() << std::endl; + } + if (kScaleOutput) { + file0 << "\n\nReference Absmax D = " << reference_abs_max_D.host_view() << std::endl; + file1 << "\n\nComputed Absmax D = " << abs_max_D.host_view() << std::endl; + } + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::conv::Conv2dProblemSize const &problem_size, + ElementCompute alpha, + ElementCompute beta) { + + cutlass::Coord<4> origin(0); + ElementCompute scaled_alpha = alpha; + if (doScaleA) { + scaled_alpha *= scale_A.host_view().at(origin); + } + if (doScaleB) { + scaled_alpha *= scale_B.host_view().at(origin); + } + + ElementCompute scaled_beta = beta; + if (doScaleC) { + scaled_beta *= scale_C.host_view().at(origin); + } + + // + // Verify + // + + cutlass::reference::host::Conv2d< + typename Conv::ElementA, typename Conv::LayoutA, + typename Conv::ElementB, typename Conv::LayoutB, + typename Conv::ElementC, typename Conv::LayoutC, + ElementCompute, ElementAccumulator, ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + 
tensor_B.host_ref(), + tensor_C.host_ref(), + tmp_D.host_ref(), + scaled_alpha, + scaled_beta + ); + + ElementCompute tmp_abs_max_Aux(0.); + ElementCompute tmp_abs_max_D(0.); + + cutlass::NumericConverter cvt_c_to_compute; + cutlass::NumericConverter cvt_accum_to_compute; + cutlass::NumericConverter cvt_compute_to_absmax; + cutlass::NumericConverter cvt_compute_to_d; + cutlass::NumericConverter cvt_compute_to_aux; + + cutlass::absolute_value_op abs; + cutlass::maximum_with_nan_propogation max; + ActivationFunctor act; + + ElementScalingFactor d_scale = kScaleOutput ? scale_D.host_view().at(origin) : ElementScalingFactor(1.); + + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + for (int k = 0; k < problem_size.K; ++k) { + ElementCompute intermediate = cvt_accum_to_compute(tmp_D.host_view().at({n, p, q, k})); + ElementCompute bias = cvt_c_to_compute(tensor_Vector.host_view().at({0, 0, 0, k})); + ElementCompute aux = intermediate + bias; + ElementCompute d = act(aux); + tmp_abs_max_Aux = max(abs(aux), tmp_abs_max_Aux); + tmp_abs_max_D = max(abs(d), tmp_abs_max_D); + reference_D.host_view().at({n, p, q, k}) = cvt_compute_to_d(d * d_scale); + + if (kScaleAux) { + reference_Aux.host_view().at({n, p, q, k}) = cvt_compute_to_aux(aux * scale_Aux.host_view().at(origin)); + } + } + } + } + } + if (kScaleAux) { + reference_abs_max_Aux.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_Aux); + } + + if (kScaleOutput) { + reference_abs_max_D.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_D); + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. 
+ bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) + { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Conv::EpilogueOutputOp::Params::ActivationParams activation_params{alpha, beta}; + typename Conv::EpilogueOutputOp::Params epilogue_params{ + activation_params, + scale_A.device_data(), + scale_B.device_data(), + scale_C.device_data(), + scale_D.device_data(), + scale_Aux.device_data(), + abs_max_Aux.device_data(), + abs_max_D.device_data() + }; + + typename Conv::Arguments arguments{ + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D.device_ref(), + tensor_Aux.device_ref(), + epilogue_params, + cutlass::conv::SplitKMode::kSerial, + tensor_Vector.device_data(), + 0 + }; + + Conv conv2d_op; + + cutlass::Status status = conv2d_op.can_implement(arguments); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + size_t workspace_size = Conv::get_workspace_size(arguments); + 
cutlass::device_memory::allocation workspace(workspace_size); + + status = conv2d_op.initialize(arguments, workspace.get()); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + cudaError_t cuda_error = cudaDeviceSynchronize(); + EXPECT_TRUE(cuda_error == cudaSuccess) << cudaGetErrorString(cuda_error); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed" << std::endl; + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ImplicitGemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity +> +bool TestAllConv2dWithAbsmax(bool scaleA=true, bool scaleB=true, bool scaleC=true) { + const Conv2dProblemVector &conv_test_sizes = Conv2dProblemVector(); + const Conv2dProblemVector &conv_blacklist_sizes = Conv2dProblemVector(); + + // + // Testbed object + // + + TestbedConv2dWithAbsMax testbed(scaleA, scaleB, scaleC); + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + Conv2dProblemVector const *problem_vectors[] = { + &conv_test_sizes, // run user specified sizes + &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + bool passed = true; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv2dProblemVector const * problem_vector : problem_vectors) { + + // Prune 
all problems with channels that aren't divisible by the number of elements accessed per + // load for operands A and B. This is meant to align with the requirements of iterators used for + // fprop kernels. + ChannelDivisibilitySpecification channel_spec(128 / cutlass::sizeof_bits::value); + auto pruned_problem_vector = prune(*problem_vector, channel_spec); + + // Run conv testbed on default convolution sizes + for(auto conv_problem : pruned_problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed &= testbed.run(conv_problem); + + if (!passed) { + return false; + } + + // test mode = convolution + passed &= testbed.run(conv_problem.reset_mode(cutlass::conv::Mode::kConvolution)); + + if (!passed) { + return false; + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..f768f5b25f425910a49058599d3854352136caef --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h @@ -0,0 +1,734 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Implicit GEMM for fused epilogue broadcast testbed + + Parallel split-k is not tested because we can just use regular conv kernel + when we need to use parallel-splitk. Broadcast can happen in the reduction + kernel. +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv2d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Conv2dWithBroadcastReferenceOp { + + using OutputOp = typename Conv2d::EpilogueOutputOp; + + using ElementCompute = typename OutputOp::ElementCompute; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + typename OutputOp::BinaryOp binary_op; + typename OutputOp::ElementwiseOp elementwise_op; + + Conv2dWithBroadcastReferenceOp() { } + + void operator()(ElementZ &Z, ElementT &T, ElementCompute conv2d, ElementCompute bias) { + ElementCompute t_full = binary_op(conv2d, bias); + T = ElementT(t_full); + + ElementCompute z_full = elementwise_op(t_full); + Z = ElementZ(z_full); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Fused testbed +// +// Y = CONV(AB, C) +// +// T[n, p, q, k] = ReductionOp(Y[n, p, q, k], 
Broadcast[k]) +// +// Z[n, p, q, k] = Elementwise(T[n, p, q, k]) +// + +template < + typename Conv2d, + typename ReferenceOp, + bool AddBroadcastFirst = false +> +class TestbedConv2dWithBroadcast { +public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + using ElementZ = typename EpilogueOutputOp::ElementZ; + using ElementT = typename EpilogueOutputOp::ElementT; + using ElementVector = typename EpilogueOutputOp::ElementVector; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + static const bool kAddBroadcastFirst = AddBroadcastFirst; + static const bool kStoreT = EpilogueOutputOp::kStoreT; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_C_reference; + cutlass::HostTensor tensor_Z_computed; + cutlass::HostTensor tensor_Z_reference; + cutlass::HostTensor tensor_T_computed; + cutlass::HostTensor tensor_T_reference; + cutlass::HostTensor tensor_Y_reference; + cutlass::HostTensor tensor_Broadcast; // Input Broadcast + +public: + + TestbedConv2dWithBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// 
Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope = 3; + } + else { + scope = 5; + } + } + else { + scope = 8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_C_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Z_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Z_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Y_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Broadcast.resize({ + 1, + 1, + 1, + 
implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(), + }); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39); + + for (int n = 0; n < tensor_C_reference.extent().n(); ++n) { + for (int p = 0; p < tensor_C_reference.extent().h(); ++p) { + for (int q = 0; q < tensor_C_reference.extent().w(); ++q) { + for (int k = 0; k < tensor_C_reference.extent().c(); ++k) { + tensor_C_reference.at({n, p, q, k}) = ElementAccumulator(tensor_C.at({n, p, q, k})); + } + } + } + } + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_Broadcast.sync_device(); + tensor_C_reference.sync_device(); + tensor_Z_computed.sync_device(); + tensor_Z_reference.sync_device(); + tensor_T_computed.sync_device(); + tensor_T_reference.sync_device(); + tensor_Y_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(1)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { 
+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + typename Conv2d::Arguments conv2d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_Z_computed.device_ref(), + {alpha, beta}, + split_k_mode, + tensor_Broadcast.device_data(), + kStoreT ? tensor_T_computed.device_data() : nullptr, + 0, // This must be zero + implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c() + ); + + // initialize the kernel + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_T_computed.sync_host(); + tensor_Z_computed.sync_host(); + + // + // Reference check + // + + // When kAddBroadcastFirst is true, add bias on the host + ElementCompute beta_ref = kAddBroadcastFirst ? 
ElementCompute(0) : beta; + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementAccumulator, + LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C_reference.device_ref(), + tensor_Y_reference.device_ref(), + alpha, + beta_ref); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_Y_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementAccumulator, + LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C_reference.host_ref(), + tensor_Y_reference.host_ref(), + alpha, + beta_ref); + +#endif + ReferenceOp reference_op; + + // compute tensor Z and tensor T + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.P : problem_size.H); ++p) { + for (int q = 0; q < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Q : problem_size.W); ++q) { + for (int k = 0; k < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? 
problem_size.K : problem_size.C); ++k) { + + ElementZ z{}; + ElementT t{}; + + ElementCompute accum = tensor_Y_reference.at({n, p, q, k}); + ElementCompute bias = ElementCompute(tensor_Broadcast.at({0, 0, 0, k})); + + + if (kAddBroadcastFirst) { + reference_op(z, t, accum + bias, + beta * ElementCompute(tensor_C_reference.at({n, p, q, k}))); + } else { + reference_op(z, t, accum, bias); + } + + tensor_Z_reference.at({n, p, q, k}) = z; + tensor_T_reference.at({n, p, q, k}) = t; + } + } + } + } + + if (kStoreT) { + passed = cutlass::reference::host::TensorEquals( + tensor_T_computed.host_view(), + tensor_T_reference.host_view()); + + EXPECT_TRUE(passed); + } + + passed = cutlass::reference::host::TensorEquals( + tensor_Z_computed.host_view(), + tensor_Z_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_"))) + << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? 
"xcorr_" : "conv_") + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n" + << "\nY reference:\n" << tensor_Y_reference.host_view() << "\n" + << "\nT reference:\n" << tensor_T_reference.host_view() << "\n" + << "\nT computed:\n" << tensor_T_computed.host_view() << "\n" + << "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n" + << "\nZ computed:\n" << tensor_Z_computed.host_view() << "\n"; + } + + return passed; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template , + bool AddBroadcastFirst = false> +bool TestSpecificConv2dWithBroadcast( + const Conv2dProblemVector & problem_sizes) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2dWithBroadcast testbed; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // + // Test + // + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator 
and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template , + bool AddBroadcastFirst = false, + bool TestSplitK = true +> +bool TestAllConv2dWithBroadcast( + const Conv2dProblemVector &conv_test_sizes = Conv2dProblemVector(), + const Conv2dProblemVector &conv_blacklist_sizes = Conv2dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2dWithBroadcast testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + Conv2dProblemVector const *problem_vectors[] = { + &conv_test_sizes, // run user specified sizes + &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv2dProblemVector const * problem_vector : problem_vectors) { + + // Run conv testbed on default convolution sizes + for(auto conv_problem : *problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS 
DGRAD's *unity* stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + +#if 0 // relax restrictions on analytic strided dgrad + // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } +#endif + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + } + + // CUTLASS DGRAD's *strided* specialization does not support split-k mode + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) 
+ {1, 1}), // dilation (dilation_h, dilation_w) + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + if (!passed) { + return false; + } + + return passed; + } + + if (!TestSplitK) + return passed; + + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( + {1, 17, 11, 288}, // input size (NHWC) + {160, 3, 3, 288}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h new file mode 
100644 index 0000000000000000000000000000000000000000..a8ec16ca5de369470f5dc50bb6f8b5e2da3da10d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h @@ -0,0 +1,643 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/tensor_reduce.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv2d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedConv2dWithReduction { +public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename 
Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + using ElementT = typename EpilogueOutputOp::ElementTensor; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + + cutlass::HostTensor tensor_Reduction; + cutlass::HostTensor tensor_Tensor; + cutlass::HostTensor tensor_Final_Reduction; + + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + +public: + + TestbedConv2dWithReduction( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope = 2; + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + + 
tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + tensor_Reduction.resize({ + 1, + 1, + (problem_size.N * problem_size.P * problem_size.Q - 1 + Conv2d::ThreadblockShape::kM) / Conv2d::ThreadblockShape::kM, + (problem_size.K) + }); + + tensor_Final_Reduction.resize({ + 1, + 1, + 1, + (problem_size.K) + }); + + tensor_Tensor.resize({(problem_size.N * problem_size.P * problem_size.Q), problem_size.K}); + + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute 
alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + split_k_mode, + tensor_Reduction.device_data(), + tensor_Tensor.device_data(), + static_cast(tensor_Reduction.stride()[0]), + static_cast(tensor_Tensor.stride()[0]) + ); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // conv2d operation with parallel split-k-mode + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // conv2d output is written to workspace in global memory + conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); + // accumulate mma for each cta in k-dimension (1.0 * A * B) + conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; + // update conv2d operator arguments + status = conv2d_op.update(conv2d_args, workspace.get()); + } + + EXPECT_TRUE(status == 
cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + // Final reduction over the partial reduction tensor + using Functor = cutlass::plus; + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementAccumulator, + ElementAccumulator, + LayoutC, + Functor, + 8, + ElementAccumulator + >; + + TensorReduction reduction(tensor_Reduction.extent(), 2); + + cutlass::DeviceAllocation reduction_device_workspace(reduction.workspace_size()); + + status = reduction.reduce( + tensor_Final_Reduction.device_ref(), + tensor_Reduction.device_ref(), + reduction_device_workspace.get(), + ElementAccumulator()); + + EXPECT_EQ(status, cutlass::Status::kSuccess); + EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess); + + // + // Reference check + // + + tensor_D_computed.sync_host(); + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + 
+#endif + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + + EXPECT_TRUE(passed); + + // + // Reference check on reduction results + // + + tensor_Reduction.sync_host(); + tensor_Final_Reduction.sync_host(); + + // compute backwards for reduction results + cutlass::HostTensor reference_Reduction; + reference_Reduction.resize({ + 1, + 1, + 1, + (problem_size.K) + }); + + for (int k = 0; k < problem_size.K; ++k) { + ElementAccumulator reduced_value = ElementAccumulator(); + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + reduced_value += tensor_D_reference.at({n, p, q, k}); + } + } + } + reference_Reduction.at({0, 0, 0, k}) = reduced_value; + } + + passed = cutlass::reference::host::TensorEquals( + tensor_Final_Reduction.host_view(), + reference_Reduction.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? 
"xcorr_" : "conv_") + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nD reference:\n" << tensor_D_reference.host_view() << "\n" + << "\nD computed:\n" << tensor_D_computed.host_view() << "\n" + << "\nreduction reference:\n" << reference_Reduction.host_view() << "\n" + << "\nreduction computed:\n" << tensor_Reduction.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestAllConv2dWithReduction( + const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), + const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2dWithReduction testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + Conv2dProblemVector const *problem_vectors[] = { + 
&conv_test_sizes, // run user specified sizes + &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv2dProblemVector const * problem_vector : problem_vectors) { + + // Run conv testbed on default convolution sizes + for(auto conv_problem : *problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + +#if 0 // relax restrictions on analytic strided dgrad + // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } +#endif + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if 
(!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + } + + // CUTLASS DGRAD's *strided* specialization does not support split-k mode + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}), // dilation (dilation_h, dilation_w) + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + if (!passed) { + return false; + } + + return passed; + } + + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( + {1, 17, 11, 288}, // input size (NHWC) + {160, 3, 3, 288}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + ); + + // Parallel SplitK is not tested. 
+ cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial, + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h new file mode 100644 index 0000000000000000000000000000000000000000..fae7d6194fb671594221a90faea7cac1e5fbeb9f --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h @@ -0,0 +1,293 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Implicit GEMM testbed sizes for Conv2d problem +*/ +#pragma once + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_types.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +namespace test { +namespace conv { +namespace device { + +using Conv3dProblemVector = std::vector; + +//////////////////////////////////////////////////////////////////////////// +/// Structure TestbedConv3dProblemSizes initializes and holds conv default and +/// important network sizes +//////////////////////////////////////////////////////////////////////////// +struct TestbedConv3dProblemSizes { + + // + // Data members + // + int minimum_channel_size; + Conv3dProblemVector conv3d_default_sizes; + Conv3dProblemVector conv3d_vnet_medical_sizes; + + // + // Methods + // + /// Default ctor + TestbedConv3dProblemSizes(int minimum_channel_size_ = 64): minimum_channel_size (minimum_channel_size_) { + + initialize_conv3d_default_sizes(); + initialize_conv3d_vnet_medical_sizes(conv3d_vnet_medical_sizes, 1 /*batch-size*/); + + filter_all(); + } + + /// Eliminates some illegal cases + void filter_all() { + + Conv3dProblemVector *problems_vectors[] = { + &conv3d_default_sizes, + &conv3d_vnet_medical_sizes + }; + + for (Conv3dProblemVector *problems : problems_vectors) { + Conv3dProblemVector filtered; + + for (cutlass::conv::Conv3dProblemSize const & problem : *problems) { + if (!(problem.C % minimum_channel_size)) { + filtered.push_back(problem); + } + } + + *problems = filtered; + } + } + + // Add a few standard convolution problem sizes + void initialize_conv3d_default_sizes() { + + 
conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 3, 3, minimum_channel_size}, // input size (NDHWC) + {8, 1, 1, 1, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 1, 8, minimum_channel_size}, // input size (NDHWC) + {8, 1, 1, 3, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 1, 8, minimum_channel_size}, // input size (NDHWC) + {8, 1, 1, 3, minimum_channel_size}, // filter size (KTRSC) + CUTLASS_STL_NAMESPACE::make_tuple( + cutlass::Coord<3>({1, 1, 1}), // near padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({0, 0, 0}) // far padding (pad_d, pad_h, pad_w) + ), + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 8, 8, 8, minimum_channel_size}, // input size (NDHWC) + {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 8, 8, 8, minimum_channel_size}, // input size (NDHWC) + {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) + CUTLASS_STL_NAMESPACE::make_tuple( + 
cutlass::Coord<3>({1, 1, 1}), // near padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({0, 0, 0}) // far padding (pad_d, pad_h, pad_w) + ), + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 16, 16, 16, minimum_channel_size}, // input size (NDHWC) + {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 15, 19, 160}, // input size (NDHWC) + {224, 1, 3, 6, 160}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 2, 1, 1, minimum_channel_size}, // input size (NDHWC) + {8, 2, 1, 1, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 7, 7, minimum_channel_size}, // input size (NDHWC) + {16, 1, 3, 3, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + 
conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 11, 15, 19, 64}, // input size (NDHWC) + {32, 4, 3, 6, 64}, // filter size (KTRSC) + cutlass::Coord<3>({2, 1, 3}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + } + + // Add vnet layers to unit testing sizes + void initialize_conv3d_vnet_medical_sizes(Conv3dProblemVector &conv3d_problem_vector, int batch_size = 1) { + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 32, 32, 32, 16}, // input size (NDHWC) + {32, 2, 2, 2, 16}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 16, 16, 16, 32}, // input size (NDHWC) + {32, 3, 3, 3, 32}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 16, 16, 16, 32}, // input size (NDHWC) + {64, 2, 2, 2, 32}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 8, 8, 8, 64}, // input size (NDHWC) + {64, 3, 3, 3, 64}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride 
(stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 8, 8, 8, 64}, // input size (NDHWC) + {128, 2, 2, 2, 64}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 4, 4, 4, 128}, // input size (NDHWC) + {128, 3, 3, 3, 128}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 8, 8, 8, 128}, // input size (NDHWC) + {128, 3, 3, 3, 128}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 16, 16, 16, 64}, // input size (NDHWC) + {64, 3, 3, 3, 64}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 32, 32, 32, 16}, // input size (NDHWC) + {64, 2, 2, 2, 16}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) 
+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 16, 16, 16, 32}, // input size (NDHWC) + {128, 2, 2, 2, 32}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + } + +}; + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..029f5effb9103bebd4ee61767795d3883541d986 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h @@ -0,0 +1,716 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "cutlass/util/reference/host/convolution.h" + +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "conv3d_problems.h" +#include "cutlass/core_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedConv3d { +public: + + using ElementA = typename Conv3d::ElementA; + using LayoutA = typename Conv3d::LayoutA; + using ElementB = typename Conv3d::ElementB; + using LayoutB = typename Conv3d::LayoutB; + using ElementC = typename Conv3d::ElementC; + using LayoutC = typename Conv3d::LayoutC; + 
using ElementAccumulator = typename Conv3d::ElementAccumulator; + using ElementCompute = typename Conv3d::ElementCompute; + using EpilogueOutputOp = typename Conv3d::EpilogueOutputOp; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv3d::kConvolutionalOperator; + + /// Reduction kernel + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOutputOp::ElementAccumulator, + EpilogueOutputOp::kCount + >; + + using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, + EpilogueOutputOp, + ReductionOp + >; + + using ReductionDevice = cutlass::reduction::device::ReduceSplitK; + using ReductionStrideIndex = typename ReductionDevice::StrideIndex; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + +public: + + TestbedConv3d( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + scope = 4; + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == 
cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv3dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv3d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + + /// Executes one test + bool run( + 
cutlass::conv::Conv3dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute()) { + + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv3d conv3d_op; + + typename Conv3d::Arguments conv3d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + split_k_mode + ); + + cutlass::Status status = conv3d_op.can_implement(conv3d_args); + if (status != cutlass::Status::kSuccess) { + std::cerr << "can_implement failed for the given problem_size: \n"; + return false; + } + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv3d::get_workspace_size(conv3d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + status = conv3d_op.initialize(conv3d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // conv3d operation with parallel split-k-mode + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // conv3d output is written to workspace in global memory + conv3d_args.ref_D.reset(reinterpret_cast(workspace.get())); + // accumulate mma for each cta in k-dimension (1.0 * A * 
B) + conv3d_args.output_op = {1.0, 0.0}; + // update conv3d operator arguments + status = conv3d_op.update(conv3d_args, workspace.get()); + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run conv3d operator + status = conv3d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // configure parallel reduction operator + ReductionDevice reduction_op; + + typename ReductionDevice::Arguments reduction_args( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), + problem_size.split_k_slices, + cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), + { + reinterpret_cast (workspace.get()), + ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_D_computed.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_C.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) + }, + // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C + {alpha, beta} + ); + + status = reduction_op.initialize(reduction_args, nullptr); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run prallel reduction kernel + status = reduction_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + } + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_D_computed.sync_host(); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = 
CreateCachedConv3dTestKey< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + alpha, + beta, + tensor_A.host_view(), + tensor_B.host_view(), + tensor_C.host_view() + ); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string conv3d_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + CachedTestResultListing cached_results(conv3d_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + } + + if (!cached_result_loaded) { + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta + ); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + cutlass::reference::host::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta + ); +#endif + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + cached_test_result.D = TensorHash(tensor_D_reference.host_view()); + + CachedTestResultListing cached_results(conv3d_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(conv3d_result_cache_name); + } + } // if (!cached_result_loaded) + + 
uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + passed = (tensor_D_hash == cached_test_result.D); + + EXPECT_EQ(tensor_D_hash, cached_test_result.D) + << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; + } + else { + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + } + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv3d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_"))) + << "ndhwc_" + << problem_size.N << "x" + << problem_size.D << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_ktrsc_" + << problem_size.K << "x" + << problem_size.T << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_d << "x" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_d << "x" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_d << "x" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? 
"xcorr_" : "conv_") + << Conv3d::ThreadblockShape::kM << "x" + << Conv3d::ThreadblockShape::kN << "x" + << Conv3d::ThreadblockShape::kK << "_" + << Conv3d::WarpShape::kM << "x" + << Conv3d::WarpShape::kN << "x" + << Conv3d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n"; + + + results << "\nD reference (hash: " << cached_test_result.D << ")\n"; + + if (!cached_result_loaded) { + results + << tensor_D_reference.host_view() << "\n"; + } + + results + << "\nD computed (hash: " << tensor_D_hash << ")\n" + << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv3d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllConv3d( + const Conv3dProblemVector & conv_test_sizes = Conv3dProblemVector(), + const Conv3dProblemVector & conv_blacklist_sizes = Conv3dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + //TestbedConv3d testbed(cutlass::Distribution::Sequential, cutlass::Distribution::Sequential, cutlass::Distribution::Sequential); + TestbedConv3d testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv3dProblemSizes conv3d_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv3d problem sizes to avoid 
duplicate runs + Conv3dProblemVector conv_tested_sizes; + + Conv3dProblemVector const *problem_vectors[] = { + &conv3d_problems.conv3d_default_sizes, + &conv3d_problems.conv3d_vnet_medical_sizes, + &conv_test_sizes + }; + + // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv3dProblemVector const * problem_vector : problem_vectors) { + + // Run conv testbed on default convolution sizes + for(auto conv_problem : *problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's unity stride specialization only support stride {1, 1, 1} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + ((ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity) || + (ImplicitGemm::UnderlyingKernel::Mma::IteratorB::kStrideSupport == + cutlass::conv::StrideSupport::kUnity))) { + if (!((conv_problem.stride_d == 1) && + (conv_problem.stride_h == 1) && + (conv_problem.stride_w == 1)) + ) { + continue; + } + } + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + } + + // Sweep split-k-slice using serial reduction with non-unity alpha and 
non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv3dProblemSize conv3d_split_k_test_size ( + {1, 8, 8, 8, 32}, // input size (NDHWC) + {32, 3, 3, 3, 32}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial, + cutlass::conv::SplitKMode::kParallel + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv3d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } + + return passed; +} + +template +bool TestSpecificConv3d( + const Conv3dProblemVector & problem_sizes) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv3d testbed; + + // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // + // Test + // + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if 
(!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..f8ba785c9d0ecbdd518711714558c9e166c0209a --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h @@ -0,0 +1,732 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM for fused epilogue broadcast testbed + + Parallel split-k is not tested because we can just use regular conv kernel + when we need to use parallel-splitk. Broadcast can happen in the reduction + kernel. +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv3d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Conv3dWithBroadcastReferenceOp { + + using OutputOp = typename Conv3d::EpilogueOutputOp; + + using ElementCompute = typename OutputOp::ElementCompute; + using ElementZ = typename 
OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + typename OutputOp::BinaryOp binary_op; + typename OutputOp::ElementwiseOp elementwise_op; + + Conv3dWithBroadcastReferenceOp() { } + + void operator()(ElementZ &Z, ElementT &T, ElementCompute conv3d, ElementCompute bias) { + ElementCompute t_full = binary_op(conv3d, bias); + T = ElementT(t_full); + + ElementCompute z_full = elementwise_op(t_full); + Z = ElementZ(z_full); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Fused testbed +// +// Y = CONV(AB, C) +// +// T[n, o, p, q, k] = ReductionOp(Y[n, o, p, q, k], Broadcast[k]) +// +// Z[n, o, p, q, k] = Elementwise(T[n, o, p, q, k]) +// + +template < + typename Conv3d, + typename ReferenceOp, + bool AddBroadcastFirst = false +> +class TestbedConv3dWithBroadcast { +public: + + using ElementA = typename Conv3d::ElementA; + using LayoutA = typename Conv3d::LayoutA; + using ElementB = typename Conv3d::ElementB; + using LayoutB = typename Conv3d::LayoutB; + using ElementC = typename Conv3d::ElementC; + using LayoutC = typename Conv3d::LayoutC; + using ElementAccumulator = typename Conv3d::ElementAccumulator; + using ElementCompute = typename Conv3d::ElementCompute; + using EpilogueOutputOp = typename Conv3d::EpilogueOutputOp; + using ElementZ = typename EpilogueOutputOp::ElementZ; + using ElementT = typename EpilogueOutputOp::ElementT; + using ElementVector = typename EpilogueOutputOp::ElementVector; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv3d::kConvolutionalOperator; + static const bool kAddBroadcastFirst = AddBroadcastFirst; + static const bool kStoreT = EpilogueOutputOp::kStoreT; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + 
cutlass::HostTensor tensor_C_reference; + cutlass::HostTensor tensor_Z_computed; + cutlass::HostTensor tensor_Z_reference; + cutlass::HostTensor tensor_T_computed; + cutlass::HostTensor tensor_T_reference; + cutlass::HostTensor tensor_Y_reference; + cutlass::HostTensor tensor_Broadcast; // Input Broadcast + +public: + + TestbedConv3dWithBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope = 3; + } + else { + scope = 5; + } + } + else { + scope = 8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv3dProblemSize const &problem_size, bool non_packed_test = false, uint64_t seed = 2019) { + + // to make the layout of tensors a little bit bigger than the problem size + cutlass::Tensor5DCoord stride_increment = cutlass::Tensor5DCoord(8, 16, 32, 32, 64); + + cutlass::Tensor5DCoord tensor_A_extent = 
implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size); + cutlass::Tensor5DCoord tensor_B_extent = implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size); + cutlass::Tensor5DCoord tensor_C_extent = implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size); + + if (non_packed_test) { + tensor_A_extent += stride_increment; + tensor_C_extent += stride_increment; + } + + tensor_A.resize(tensor_A_extent); + tensor_B.resize(tensor_B_extent); + tensor_C.resize(tensor_C_extent); + tensor_C_reference.resize(tensor_C_extent); + tensor_Z_computed.resize(tensor_C_extent); + tensor_Z_reference.resize(tensor_C_extent); + tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Y_reference.resize(tensor_C_extent); + tensor_Broadcast.resize({ + 1, + 1, + 1, + 1, + implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(), + }); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39); + for (int n = 0; n < tensor_C_reference.extent().n(); ++n) { + for (int o = 0; o < tensor_C_reference.extent().d(); ++o) { + for (int p = 0; p < tensor_C_reference.extent().h(); ++p) { + for (int q = 0; q < tensor_C_reference.extent().w(); ++q) { + for (int k = 0; k < tensor_C_reference.extent().c(); ++k) { + tensor_C_reference.at({n, o, p, q, k}) = ElementAccumulator(tensor_C.at({n, o, p, q, k})); + } + } + } + } + } + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_Broadcast.sync_device(); + tensor_C_reference.sync_device(); + tensor_Z_computed.sync_device(); + tensor_Z_reference.sync_device(); + tensor_T_computed.sync_device(); + tensor_T_reference.sync_device(); + 
tensor_Y_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv3d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv3dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + bool non_packed_test = false, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(1)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv3d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size, non_packed_test); + + // configure the operator + Conv3d conv3d_op; + typename Conv3d::Arguments conv3d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_Z_computed.device_ref(), + {alpha, beta}, + split_k_mode, + tensor_Broadcast.device_data(), + kStoreT ? 
tensor_T_computed.device_data() : nullptr, + 0, // This must be zero + implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c() + ); + + // initialize the kernel + size_t workspace_size = Conv3d::get_workspace_size(conv3d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv3d_op.initialize(conv3d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // run conv3d operator + status = conv3d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_T_computed.sync_host(); + tensor_Z_computed.sync_host(); + + // + // Reference check + // + + // When kAddBroadcastFirst is true, add bias on the host + ElementCompute beta_ref = kAddBroadcastFirst ? 
ElementCompute(0) : beta; + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementAccumulator, + LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C_reference.device_ref(), + tensor_Y_reference.device_ref(), + alpha, + beta_ref); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_Y_reference.sync_host(); + +#else + + cutlass::reference::host::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementAccumulator, + LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C_reference.host_ref(), + tensor_Y_reference.host_ref(), + alpha, + beta_ref); + +#endif + ReferenceOp reference_op; + + // compute tensor Z and tensor T + for (int n = 0; n < problem_size.N; ++n) { + for (int o = 0; o < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Z : problem_size.D); ++o) { + for (int p = 0; p < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.P : problem_size.H); ++p) { + for (int q = 0; q < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Q : problem_size.W); ++q) { + for (int k = 0; k < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? 
problem_size.K : problem_size.C); ++k) { + + ElementZ z{}; + ElementT t{}; + + ElementCompute accum = tensor_Y_reference.at({n, o, p, q, k}); + ElementCompute bias = ElementCompute(tensor_Broadcast.at({0, 0, 0, 0, k})); + + + if (kAddBroadcastFirst) { + reference_op(z, t, accum + bias, + beta * ElementCompute(tensor_C_reference.at({n, o, p, q, k}))); + } else { + reference_op(z, t, accum, bias); + } + + tensor_Z_reference.at({n, o, p, q, k}) = z; + tensor_T_reference.at({n, o, p, q, k}) = t; + } + } + } + } + } + + if (kStoreT) { + passed = cutlass::reference::host::TensorEquals( + tensor_T_computed.host_view(), + tensor_T_reference.host_view()); + + EXPECT_TRUE(passed); + } + + passed = cutlass::reference::host::TensorEquals( + tensor_Z_computed.host_view(), + tensor_Z_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv3d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_"))) + << "nnhwc_" + << problem_size.N << "x" + << problem_size.D << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.T << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_d << "x" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_d << "x" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_d << "x" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? 
"xcorr_" : "conv_") + << (non_packed_test ? "non_packed_tensor_test_" : "packed_tensor_test_") + << Conv3d::ThreadblockShape::kM << "x" + << Conv3d::ThreadblockShape::kN << "x" + << Conv3d::ThreadblockShape::kK << "_" + << Conv3d::WarpShape::kM << "x" + << Conv3d::WarpShape::kN << "x" + << Conv3d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n" + << "\nY reference:\n" << tensor_Y_reference.host_view() << "\n" + << "\nT reference:\n" << tensor_T_reference.host_view() << "\n" + << "\nT computed:\n" << tensor_T_computed.host_view() << "\n" + << "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n" + << "\nZ computed:\n" << tensor_Z_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv3dProblemSizes +// Additionally, each conv3d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template , + bool AddBroadcastFirst = false, + bool TestSplitK = true +> +bool TestAllConv3dWithBroadcast( + const Conv3dProblemVector &conv_test_sizes = Conv3dProblemVector(), + const Conv3dProblemVector &conv_blacklist_sizes = Conv3dProblemVector(), + bool non_packed_test = false) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv3dWithBroadcast testbed; + + // + // Get 
conv problem sizes to run conv operator + // + TestbedConv3dProblemSizes conv3d_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv3d problem sizes to avoid duplicate runs + Conv3dProblemVector conv_tested_sizes; + + Conv3dProblemVector const *problem_vectors[] = { + &conv3d_problems.conv3d_default_sizes, + &conv3d_problems.conv3d_vnet_medical_sizes, + &conv_test_sizes + }; + + // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv3dProblemVector const * problem_vector : problem_vectors) { + + // Run conv testbed on default convolution sizes + for(auto conv_problem : *problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_d == 1) && + (conv_problem.stride_h == 1) && + (conv_problem.stride_w == 1)) + ) { + continue; + } + } + +#if 0 // relax restrictions on analytic strided dgrad + // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_d == 1) && (conv_problem.stride_h == 1) && (conv_problem.stride_w == 
1))) { + continue; + } + } +#endif + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial, non_packed_test); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial, non_packed_test); + + if (!passed) { + return false; + } + } + } + + if (!TestSplitK) + return passed; + + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv3d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv3dProblemSize conv3d_split_k_test_size ( + {1, 8, 8, 8, 32}, // input size (NDHWC) + {32, 3, 3, 3, 32}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv3d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + false,/*non_packed_test*/ + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } + 
+ return passed; +} + +template , + bool AddBroadcastFirst = false> +bool TestSpecificConv3dWithBroadcast( + const Conv3dProblemVector & problem_sizes, + bool non_packed_test = false) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv3dWithBroadcast testbed; + + // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // + // Test + // + + // test mode = xcross, non_packed_test = false + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial, non_packed_test); + + if (!passed) { + return false; + } + + // test mode = convolution, non_packed_test = false + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial, non_packed_test); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..cef5f981c595dfbbb95658fb757865b219538192 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h @@ -0,0 +1,473 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Depthwise Direct Conv testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "../cache_testbed_output.h" +#include "conv2d_problems.h" +#include "cutlass/conv/device/direct_convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/cutlass.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedDepthwiseDirectConv2d { + public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + + public: + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_reordered_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + int tested_problem_count; + + public: + TestbedDepthwiseDirectConv2d(cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = 
cutlass::Distribution::Uniform, + uint64_t seed_ = 2080) + : init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) {} + + /// Helper to initialize a tensor view + template + void initialize_tensor(cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + if (dist_kind == cutlass::Distribution::Uniform) { + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } else if (bits == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope = 3; + } else { + scope = 5; + } + } else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform(view, seed, scope, -scope, 0); + } else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + + } else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } else { + } + } + + void initialize(cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_reordered_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_reordered_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 
39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_reordered_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient(int smem_size) const { + // + // Determine SMEM requirements and waive if not satisfied + // + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < static_cast(smem_size)) { + return false; + } + + return true; + } + + /// Executes one test + bool run(cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1.5), + ElementCompute beta = ElementCompute(1)) { + // increment tested problem count run by the testbed + tested_problem_count++; + +#if 0 // display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " + << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? 
"(serial)" : "(parallel)") + << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args(problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + tensor_reordered_B.device_ref(), + split_k_mode); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.can_implement(problem_size); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + if (!sufficient(conv2d_op.get_smem_size())) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run." 
<< std::endl; + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " << cudaGetErrorString(result); + + tensor_D_computed.sync_host(); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = + CreateCachedConv2dTestKey(kConvolutionalOperator, + problem_size, + alpha, + beta, + tensor_A.host_view(), + tensor_B.host_view(), + tensor_C.host_view()); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string conv2d_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + } + + if (!cached_result_loaded) { +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d(kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d(kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + +#endif + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + cached_test_result.D = TensorHash(tensor_D_reference.host_view()); + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(conv2d_result_cache_name); + } + } // if (!cached_result_loaded) + + uint32_t 
tensor_D_hash = TensorHash(tensor_D_computed.host_view()); + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + passed = (tensor_D_hash == cached_test_result.D); + + EXPECT_EQ(tensor_D_hash, cached_test_result.D) + << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; + } + else { + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + } + + EXPECT_TRUE(passed); + + std::stringstream ss_problem_size_text; + ss_problem_size_text << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_"); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_DirectConv_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? 
"dgrad_" : "wgrad_")) + << ss_problem_size_text.str() + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n"; + + results << "\nD reference (hash: " << cached_test_result.D << ")\n"; + + if (!cached_result_loaded) { + results + << tensor_D_reference.host_view() << "\n"; + } + + results + << "\nD computed (hash: " << tensor_D_hash << ")\n" + << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestSpecificDepthwiseDirectConv2d(const Conv2dProblemVector &problem_sizes) { + bool passed = true; + + // + // Testbed object + // + TestbedDepthwiseDirectConv2d testbed; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (auto conv_problem : problem_sizes) { + // + // Test + // + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp new file mode 100644 index 0000000000000000000000000000000000000000..54c11281e14b813b249d7f9710542843b37bcc68 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp @@ -0,0 +1,1385 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief CUTLASS 3.x Implicit GEMM testbed sizes for ConvNd problem +*/ +#pragma once + +#include "cutlass/conv/convnd_problem_shape.hpp" +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test::conv::device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +std::vector> +inline +get_conv_problem_vector(); + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Fprop +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialization for 1D fprop problems +template<> +std::vector> inline +get_conv_problem_vector<1, cutlass::conv::Operator::kFprop>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nwc + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed input strides. 
+ problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nwc + {800, 80, 1}, // stride (nwc) + {64, 1, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed output strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nwc + {512, 64, 1}, // stride (nwc) + {64, 1, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nqk) + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, + {16,1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {96, 1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 64}, + {256, 1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {256, 3, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, symmetric padding with c % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {1}, + {1}, + {1}, + {1}, + 1 + }); + // 4 filter, asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {256, 4, 64}, + {0}, + {1}, + {1}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and tstride of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {256, 3, 64}, + {0}, + {1}, + {2}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and 
dilation of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {256, 3, 64}, + {0}, + {1}, + {1}, + {2}, + 1 + }); + return problem_shapes; +} + +// Specialization for 2D fprop problems +template<> +std::vector> inline +get_conv_problem_vector<2, cutlass::conv::Operator::kFprop>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // nhwc + {64, 1, 1, 64}, // krsc + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // non-packed input strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // nhwc + {8000, 800, 80, 1}, // stride (nhwc) + {64, 1, 1, 64}, // krsc + {64, 64, 64, 1}, // stride (krsc) + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // non-packed output strides. 
+ problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // nhwc + {4096, 512, 64, 1}, // stride (nhwc) + {64, 1, 1, 64}, // krsc + {64, 64, 64, 1}, // stride (krsc) + {8000, 800, 80, 1}, // stride (npqk) + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, + {16, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 64}, + {96, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 8, 64}, + {256, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 64}, + {256, 3, 3, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, symmetric padding with c % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 32}, + {256, 3, 3, 32}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,2/1,2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 64}, + {256, 2, 5, 64}, + {1, 1}, + {2, 2}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 7, 7, 64}, + {256, 2, 5, 64}, + {1, 1}, + {0, 0}, + {2, 3}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 
1}, + {0, 0}, + {1, 1}, + {2, 3}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 15, 64}, + {256, 2, 5, 64}, + {1, 1}, + {0, 0}, + {2, 3}, + {2, 3}, + 1 + }); + return problem_shapes; +} + +// Specialization for 3D fprop problems +template<> +std::vector> inline +get_conv_problem_vector<3, cutlass::conv::Operator::kFprop>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 64}, // ndhwc + {64, 1, 1, 1, 64}, // ktrsc + {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) + {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + }); + // non-packed input output strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 64}, // ndhwc + {8000, 8000, 800, 80, 1}, // stride (ndhwc) + {64, 1, 1, 1, 64}, // ktrsc + {64, 64, 64, 64, 1}, // stride (ktrsc) + {8000, 8000, 800, 80, 1}, // stride (nzpqk) + {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) + {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 64}, + {16, 1, 1, 1, 64}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // N = 7 and K = 256 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 8, 8, 64}, + {96, 1, 1, 1, 64}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x3x3 + no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 64}, + {96, 3, 3, 3, 
64}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x3x3 + symmetric padding with c % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 32}, + {96, 3, 3, 3, 32}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + symmetric padding 111 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 64}, + {96, 3, 4, 5, 64}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 64}, + {96, 3, 4, 5, 64}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {96, 3, 4, 5, 64}, + {1, 0, 1}, + {0, 2, 0}, + {2, 2, 3}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {96, 3, 4, 5, 64}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {2, 2, 3}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ stride, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {96, 3, 4, 5, 64}, + {1, 0, 1}, + {0, 2, 0}, + {2, 2, 3}, + {2, 2, 3}, + 1 + }); + return problem_shapes; +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Wgrad +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialization for 1D wgrad problems +template<> +std::vector> inline +get_conv_problem_vector<1, cutlass::conv::Operator::kWgrad>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 
8, 64}, // nwc + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, + {16,1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {96, 1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 64}, + {256, 1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, symmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {1}, + {1}, + {1}, + {1}, + 1 + }); + // 4 filter, asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 4, 32}, + {0}, + {1}, + {1}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and tstride of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {0}, + {1}, + {2}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and dilation of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {0}, + {1}, + {1}, + {2}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2048 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1024, 128}, + {640, 1, 128}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2080 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1040, 128}, + {640, 1, 128}, + {0}, + {0}, + {1}, 
+ {1}, + 1 + }); + return problem_shapes; +} + +// Specialization for 2D wgrad problems +template<> +std::vector> inline +get_conv_problem_vector<2, cutlass::conv::Operator::kWgrad>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // nhwc + {64, 1, 1, 64}, // krsc + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, + {16, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 64}, + {96, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 8, 64}, + {256, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 32}, + {256, 3, 3, 32}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, symmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 32}, + {256, 3, 3, 32}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 32}, + {256, 2, 5, 32}, + {1, 1}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 15, 16, 32}, + {256, 2, 5, 32}, + {1, 1}, + {0, 0}, + {2, 3}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric 
padding 1,0/1,0, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 16, 32}, + {256, 2, 5, 32}, + {1, 1}, + {0, 0}, + {1, 1}, + {2, 3}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 15, 32}, + {256, 2, 5, 32}, + {1, 1}, + {0, 0}, + {2, 3}, + {2, 3}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2048 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 64, 16, 128}, + {640, 1, 1, 128}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2080 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 65, 16, 128}, + {640, 1, 1, 128}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + return problem_shapes; +} + +// Specialization for 3D wgrad problems +template<> +std::vector> inline +get_conv_problem_vector<3, cutlass::conv::Operator::kWgrad>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 8, 8, 64}, // ndhwc + {64, 1, 1, 1, 64}, // ktrsc + {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) + {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + }); + // Filter 3x3x3 + no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 32}, + {96, 3, 3, 3, 32}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 32}, + {96, 3, 4, 5, 32}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ stride + problem_shapes.push_back({ + 
cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 32}, + {96, 3, 4, 5, 32}, + {1, 0, 1}, + {0, 2, 0}, + {2, 2, 3}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 32}, + {96, 3, 4, 5, 32}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {2, 2, 3}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2048 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 64, 16, 128}, + {640, 1, 1, 1, 128}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2080 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 65, 16, 128}, + {640, 1, 1, 1, 128}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + return problem_shapes; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Grouped Wgrad +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Get problem size vectors for group conv problems +template +std::vector> +inline +get_grouped_conv_problem_vector(int GroupsPerTile); + +// Specialization for 3D wgrad problems +template<> +std::vector> inline +get_grouped_conv_problem_vector<3, cutlass::conv::Operator::kWgrad>(int GroupsPerTile) { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + + if (GroupsPerTile == 1) { + // channel_per_group == 64 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 16, 16, 2048}, // ndhwc + {2048, 1, 3, 3, 64}, // ktrsc + {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) + {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 32 // groups + }); + } + else if (GroupsPerTile == 2) { + // channel_per_group == 
32 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 16, 16, 1024}, // ndhwc + {1024, 1, 3, 3, 32}, // ktrsc + {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) + {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 32 // groups + }); + } + else if (GroupsPerTile == 4) { + // channel_per_group == 16 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 16, 16, 512}, // ndhwc + {512, 1, 3, 3, 16}, // ktrsc + {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) + {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 32 // groups + }); + } + else if (GroupsPerTile == 8) { + // channel_per_group == 8 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 16, 16, 256}, // ndhwc + {256, 1, 3, 3, 8}, // ktrsc + {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) + {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 32 // groups + }); + } + return problem_shapes; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Unit Stride Dgrad +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialization for 1D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, false>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group 
+ }); + // non-packed input strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 1, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed output strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nqk + {512, 64, 1}, // stride (nqk) + {64, 1, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 16}, + {64, 1, 16}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 96}, + {64, 1, 96}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 256}, + {64, 1, 256}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 256}, + {64, 3, 256}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, symmetric padding with k % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 256}, + {32, 3, 256}, + {1}, + {1}, + {1}, + {1}, + 1 + }); + // 4 filter, asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 256}, + {64, 4, 256}, + {0}, + {1}, + {1}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and dilation of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 64}, + {256, 3, 64}, + {0}, + {1}, + {1}, + {2}, + 1 + 
}); + return problem_shapes; +} + +// Specialization for 2D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<2, cutlass::conv::Operator::kDgrad, false>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // npqk + {64, 1, 1, 64}, // krsc + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // non-packed input strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // npqk + {8000, 800, 80, 1}, // stride (npqk) + {64, 1, 1, 64}, // krsc + {64, 64, 64, 1}, // stride (krsc) + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // non-packed output strides. 
+ problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // npqk + {4096, 512, 64, 1}, // stride (npqk) + {64, 1, 1, 64}, // krsc + {64, 64, 64, 1}, // stride (krsc) + {8000, 800, 80, 1}, // stride (nhwc) + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 16}, + {64, 1, 1, 16}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 96}, + {64, 1, 1, 96}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 8, 256}, + {64, 1, 1, 256}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 256}, + {64, 3, 3, 256}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, symmetric padding with k % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 256}, + {32, 3, 3, 256}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 256}, + {64, 2, 5, 256}, + {1, 1}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 1}, + {0, 0}, + {1, 1}, + {2, 3}, + 1 + }); + return problem_shapes; +} + +// Specialization for 3D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<3, cutlass::conv::Operator::kDgrad, 
false>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 16}, + {64, 1, 1, 1, 16}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // non-packed input output strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 64}, // nzpqk + {8000, 8000, 800, 80, 1}, // stride (nzpqk) + {64, 1, 1, 1, 64}, // ktrsc + {64, 64, 64, 64, 1}, // stride (ktrsc) + {8000, 8000, 800, 80, 1}, // stride (ndhwc) + {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) + {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + }); + // N = 7 and K = 256 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 8, 8, 96}, + {64, 1, 1, 1, 96}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + symmetric padding 111 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 96}, + {64, 3, 4, 5, 96}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 96}, + {64, 3, 4, 5, 96}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {64, 3, 4, 5, 96}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {2, 2, 3}, + 1 + }); + return problem_shapes; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Strided Dgrad +///////////////////////////////////////////////////////////////////////////////////////////////// + +// 
Specialization for 1D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, true>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // Test TMA truncation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 512, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {2}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1024, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {4}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 2048, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {8}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed input/output strides. + // stride divides dilation + // asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 3, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {0}, // padding lower (pad_w) + {1}, // padding upper (pad_w) + {2}, // stride (stride_w) + {4}, // dilation (dilation_w) + 1 // group + }); + // non-packed input/output strides. + // dilation divides stride + // asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 3, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {1}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {4}, // stride (stride_w) + {2}, // dilation (dilation_w) + 1 // group + }); + // non-packed input/output strides. 
+ // stride dilation dont divide + // asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 3, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {1}, // padding lower (pad_w) + {2}, // padding upper (pad_w) + {2}, // stride (stride_w) + {3}, // dilation (dilation_w) + 1 // group + }); + return problem_shapes; +} + +// Specialization for 2D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<2, cutlass::conv::Operator::kDgrad, true>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + // mode 0 stride divides dilation + // mode 1 dilation divides stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 0}, + {0, 1}, + {2, 4}, + {4, 2}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + // mode 0 dilation divides stride + // mode 1 stride divides dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 0}, + {0, 1}, + {4, 2}, + {2, 4}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + // stride dilation dont divide + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 0}, + {0, 1}, + {3, 2}, + {2, 3}, + 1 + }); + return problem_shapes; +} + +// Specialization for 3D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<3, cutlass::conv::Operator::kDgrad, true>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {64, 3, 4, 5, 96}, + {1, 0, 1}, + {0, 2, 0}, + {2, 1, 2}, + {4, 2, 3}, + 1 + }); + 
return problem_shapes; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp new file mode 100644 index 0000000000000000000000000000000000000000..99ba9c407cec38e919812fedeee38ba75d9129f7 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp @@ -0,0 +1,768 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed for 3.x API +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "../../common/cutlass_unit_test.h" + +#include "cute/tensor.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "../test/unit/gemm/device/gemm_testbed_3x.hpp" + +#include "thrust/universal_vector.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/host/conv.hpp" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "conv_problem_sizes.hpp" +#include "../cache_testbed_output.h" + +#include + +#include "cute/layout.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test::conv::device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Initializes a flat device buffer +template 
+static void +initialize_values( + thrust::universal_vector& dst_ptr, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + if (cutlass::Distribution::Uniform == dist_kind) { + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + scope = 4; + } + else { + scope = 8; + } + cutlass::reference::host::BlockFillRandomUniform( + dst_ptr.data().get(), dst_ptr.size(), seed, scope, -scope, 0); + } + else if (cutlass::Distribution::Identity == dist_kind) { + cutlass::reference::host::BlockFillRandomUniform( + dst_ptr.data().get(), dst_ptr.size(), seed, 0, 0, 0); + } + else if (cutlass::Distribution::Gaussian == dist_kind) { + cutlass::reference::host::BlockFillRandomGaussian(dst_ptr.data().get(), dst_ptr.size(), seed, 0, 0.5); + } + else if (cutlass::Distribution::Sequential == dist_kind) { + cutlass::reference::host::BlockFillSequential(dst_ptr.data().get(), dst_ptr.size()); + } + else { + std::cerr << "Invalid distribution kind!\n."; + exit(1); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// utils for sparse or dense conv parameters + +template +struct DenseConvParams { + // Default Kernel data types + using ElementA = typename Conv::ConvKernel::ElementA; + using ElementB = typename Conv::ConvKernel::ElementB; + + static constexpr cutlass::conv::Operator ConvOp = Conv::DispatchPolicy::ConvOp; + static constexpr int NumSpatialDimensions = Conv::NumSpatialDimensions; + using ProblemShape = cutlass::conv::ConvProblemShape; + + // get the default arguments without sparse data + auto get_mainloop_arguments( + [[maybe_unused]] ProblemShape const& problem_shape, + thrust::universal_vector& tensor_A, + thrust::universal_vector& tensor_B + ) { + auto args = typename Conv::ConvKernel::MainloopArguments { + tensor_A.data().get(), + tensor_B.data().get(), + }; + return args; + } +}; + +template +struct SparseConvParams { +}; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct ConvTestbed { + // Kernel data types + using ElementA = typename Conv::ConvKernel::ElementA; + using ElementB = typename Conv::ConvKernel::ElementB; + using ElementC = cute::conditional_t, + typename Conv::ConvKernel::ElementD, typename Conv::ConvKernel::ElementC>; + using ElementD = typename Conv::ConvKernel::ElementD; + using ElementAccumulator = typename Conv::ConvKernel::ElementAccumulator; + + // ConvTest for sparse kernel + static constexpr bool isSparseEnabled = isSparseEnabled_; + using ConvParams = cute::conditional_t, DenseConvParams>; + ConvParams params; + + // + // FusionOperation derived types/queries + // + using FusionOp = typename Conv::EpilogueOutputOp; + + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + using ElementScalar = typename FusionOp::ElementScalar; + using ElementCompute = typename FusionOp::ElementCompute; + using BiasType = typename cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithBias::type; + using ElementBias = non_void_t; + using ActivationType = non_void_t::type, + cutlass::epilogue::thread::Identity>; + static constexpr bool IsActivationEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithActivation::value; + using ActivationFunctor = cute::conditional_t>; + + static constexpr bool IsBiasEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithBias::value && + !cute::is_same_v; + static constexpr bool IsPerChannelScaleEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithPerChannelScaled::value; + + static constexpr bool DisableSource = cute::is_void_v; + + static constexpr bool IsResidualEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithResidualAdd::value; + + using StrideC = typename 
Conv::ConvKernel::StrideC; + using StrideD = typename Conv::ConvKernel::StrideD; + using ThreadEpilogueOp = typename Conv::ConvKernel::CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr cutlass::conv::Operator ConvOp = Conv::DispatchPolicy::ConvOp; + static constexpr int NumSpatialDimensions = Conv::NumSpatialDimensions; + using ProblemShape = cutlass::conv::ConvProblemShape; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using MaxSwizzleSize = typename gemm::device::detail::MaxSwizzleSize; + using Splits = typename gemm::device::detail::Splits; + + using Schedule = typename Conv::DispatchPolicy::Schedule; + /// Initialization + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_C = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_disable = cutlass::Distribution::Identity; // all zeros + uint64_t seed = 6090; + float epsilon = 0.0f; + int split_p_slices = 1; + thrust::universal_vector tensor_A; + thrust::universal_vector tensor_B; + thrust::universal_vector tensor_C; + thrust::universal_vector tensor_D_computed; + thrust::universal_vector tensor_D_reference; + thrust::universal_vector tensor_bias; + thrust::universal_vector tensor_alpha; + thrust::universal_vector tensor_beta; + + // Return true on success, else false + bool initialize(ProblemShape const& problem_shape, uint64_t seed = 6090) { + tensor_A.resize(sizeof(ElementA) * problem_shape.size_A()); + tensor_B.resize(sizeof(ElementB) * problem_shape.size_B()); + tensor_C.resize(sizeof(ElementC) * problem_shape.size_C()); + tensor_D_computed.resize(sizeof(ElementD) * problem_shape.size_C()); + 
tensor_D_reference.resize(sizeof(ElementD) * problem_shape.size_C()); + tensor_bias.resize(sizeof(ElementBias) * cute::size(cute::get<0>(problem_shape.get_shape_B()))); + if constexpr (IsPerChannelScaleEnabled) { + tensor_alpha.resize(sizeof(ElementScalar) * cute::size(cute::get<0>(problem_shape.get_shape_B()))); + tensor_beta.resize(sizeof(ElementScalar) * cute::size(cute::get<0>(problem_shape.get_shape_B()))); + } + initialize_values(tensor_A, init_A, seed); + initialize_values(tensor_B, init_B, seed * 11); + initialize_values(tensor_C, init_C, seed * 17); + initialize_values(tensor_bias, init_bias, seed * 19); + if constexpr (IsPerChannelScaleEnabled) { + initialize_values(tensor_alpha, init_bias, seed * 23); + if constexpr (DisableSource) { + initialize_values(tensor_beta, init_disable, seed * 27); + } + else { + initialize_values(tensor_beta, init_bias, seed * 27); + } + } + + bool flag = true; + if constexpr (isSparseEnabled) { + flag &= params.initialize(problem_shape, tensor_B, static_cast(seed + 2023)); + } + + return flag; + } + + // Determine SMEM requirements and waive if not satisfied + bool sufficient() const { + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + int max_smem_size; + result = cudaDeviceGetAttribute(&max_smem_size, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaDeviceGetAttribute() failed"); + } + + return max_smem_size >= Conv::ConvKernel::SharedStorageSize; + } + + auto transform_shape_and_stride_with_groups(ProblemShape const& problem_shape) { + using TensorExtent = cute::array; + using TensorStride = cute::array; + + TensorExtent shape_a_g{}; + TensorExtent shape_b_g{}; + TensorExtent shape_c_g{}; + TensorStride stride_a_g{}; + TensorStride stride_b_g{}; + TensorStride stride_c_g{}; + + auto shape_a = cute::reverse(problem_shape.shape_A); + auto 
shape_b = cute::reverse(problem_shape.shape_B); + auto shape_c = cute::reverse(problem_shape.shape_C); + auto stride_a = cute::reverse(problem_shape.stride_A); + auto stride_b = cute::reverse(problem_shape.stride_B); + auto stride_c = cute::reverse(problem_shape.stride_C); + + int32_t G = problem_shape.groups; + + if constexpr (ConvOp == cutlass::conv::Operator::kFprop || + ConvOp == cutlass::conv::Operator::kDgrad) { + // shape_a_g = (c,w,h,d,n,g) or (k,q,p,z,n,g) + // shape_b_g = (c,s,r,k,t,g) + // shape_c_g = (k,q,p,z,n,g) or (c,w,h,d,n,g) + shape_a_g = cute::to_array(tuple_cat( + cute::make_shape(cute::size<0>(shape_a) / G), + cute::take<1,NumSpatialDimensions + 2>(shape_a), + cute::make_shape(G))); + shape_b_g = cute::to_array(tuple_cat( + cute::take<0,NumSpatialDimensions + 1>(shape_b), + cute::make_shape(cute::size(shape_b) / G, G))); + shape_c_g = cute::to_array(tuple_cat( + cute::make_shape(cute::size<0>(shape_c) / G), + cute::take<1,NumSpatialDimensions + 2>(shape_c), + cute::make_shape(G))); + + stride_a_g = cute::to_array(append(stride_a, cute::size<0>(shape_a) / G)); + stride_b_g = cute::to_array(append(stride_b, + cute::size(stride_b) * cute::size(shape_b) / G)); + stride_c_g = cute::to_array(append(stride_c, cute::size<0>(shape_c) / G)); + } + else if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { + // shape_a_g = (k,q,p,z,n,g) + // shape_b_g = (c,w,h,d,n,g) + // shape_c_g = (c,s,r,k,t,g) + shape_a_g = cute::to_array(tuple_cat( + cute::make_shape(cute::size<0>(shape_a) / G), + cute::take<1,NumSpatialDimensions + 2>(shape_a), + cute::make_shape(G))); + shape_b_g = cute::to_array(tuple_cat( + cute::make_shape(cute::size<0>(shape_b) / G), + cute::take<1,NumSpatialDimensions + 2>(shape_b), + cute::make_shape(G))); + shape_c_g = cute::to_array(tuple_cat( + cute::take<0,NumSpatialDimensions + 1>(shape_c), + cute::make_shape(cute::size(shape_c) / G, G))); + + stride_a_g = cute::to_array(append(stride_a, cute::size<0>(shape_a) / G)); + stride_b_g = 
cute::to_array(append(stride_b, cute::size<0>(shape_b) / G)); + stride_c_g = cute::to_array(append(stride_c, + cute::size(stride_c) * cute::size(shape_c) / G)); + } + + return make_tuple(shape_a_g, shape_b_g, shape_c_g, + stride_a_g, stride_b_g, stride_c_g); + } + + // Executes one test + bool run( + ProblemShape const& problem_shape, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + dim3 cluster_shape = dim3(0, 0, 0), + dim3 cluster_shape_fallback = dim3(0, 0, 0), + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + MaxSwizzleSize max_swizzle = MaxSwizzleSize{}, + Splits splits = Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic + ) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device.\n"; + } + return true; + } + + bool ret = initialize(problem_shape); + + if (!ret) { + std::cerr << "initialize failed for the given problem_shape: \n"; + return false; + } + + cutlass::KernelHardwareInfo hw_info; + cudaGetDevice(&hw_info.device_id); + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + hw_info.cluster_shape = cluster_shape; + hw_info.cluster_shape_fallback = cluster_shape_fallback; + + // configure the operator + Conv conv_op; + auto stride_C = StrideC{}; + auto stride_D = StrideD{}; + if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { + stride_C = cutlass::make_cute_packed_stride( + StrideC{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp); + stride_D = cutlass::make_cute_packed_stride( + StrideD{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp); + } + // Need to support non-packed output strides for fprop and dgrad kernel. 
+ else { + cute::for_each(cute::make_seq(StrideC{})>{}, [&](auto i) { + cute::get<0, i>(stride_C) = problem_shape.stride_C[ProblemShape::RankT-2-i]; + }); + cute::for_each(cute::make_seq(StrideD{})>{}, [&](auto i) { + cute::get<0, i>(stride_D) = problem_shape.stride_C[ProblemShape::RankT-2-i]; + }); + } + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + typename Conv::ConvKernel::TileScheduler::Arguments scheduler_args{}; + if constexpr (cute::is_same_v) { + scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order, decomposition_mode }; + } + + auto mainloop_args = params.get_mainloop_arguments(problem_shape, tensor_A, tensor_B); + + auto epilogue_args = typename Conv::ConvKernel::EpilogueArguments { + {}, + tensor_C.data().get(), + stride_C, + tensor_D_computed.data().get(), + stride_D, + }; + + auto args = typename Conv::Arguments { + problem_shape, + mainloop_args, // MainloopArguments + epilogue_args, // EpilogueArguments + hw_info, + scheduler_args + }; + + auto &fusion_args = args.epilogue.thread; + + fusion_args.alpha = alpha; + fusion_args.beta = beta; + + if constexpr (IsPerChannelScaleEnabled) { + fusion_args.alpha_ptr = tensor_alpha.data().get(); + fusion_args.beta_ptr = tensor_beta.data().get(); + } + + if constexpr (IsBiasEnabled) { + fusion_args.bias_ptr = tensor_bias.data().get(); + } + + // Clamp bound + if constexpr (cute::is_same_v>) { + fusion_args.activation.lower_bound = CUTLASS_STL_NAMESPACE::numeric_limits::lowest(); + fusion_args.activation.upper_bound = CUTLASS_STL_NAMESPACE::numeric_limits::max(); + } + + // Scale + if constexpr (cute::is_same_v> || + cute::is_same_v> || + cute::is_same_v> || + cute::is_same_v> ) { + fusion_args.activation.scale = ElementCompute{1}; + } + + // LeakyRelu + if constexpr (cute::is_same_v> 
) { + fusion_args.activation.leaky_alpha = ElementCompute{0}; + } + + cutlass::Status status = cutlass::Status::kInvalid; + + status = conv_op.can_implement(args); + EXPECT_EQ(conv_op.can_implement(args), cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + std::cerr << "can_implement failed for the given problem_shape: \n"; + print(problem_shape); + return false; + } + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv::get_workspace_size(args); + thrust::universal_vector workspace(workspace_size); + + status = conv_op.initialize(args, workspace.data().get()); + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // run conv3d operator + status = conv_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + bool passed = false; + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " Kernel execution error: " + << cudaGetErrorString(result); + + // Create cute::Tensors using the logical rank-3 MNK multi-mode shapes the mainloop gives us + auto [shape_mA, shape_mB, shape_mC, stride_mA, stride_mB, stride_mC] = + transform_shape_and_stride_with_groups(problem_shape); + auto shape_mBias = cute::make_shape(cute::size(cute::get<0>(problem_shape.get_shape_B()))); + + auto mA = make_tensor(tensor_A.data().get(), make_layout(shape_mA, stride_mA)); + auto mB = make_tensor(tensor_B.data().get(), make_layout(shape_mB, stride_mB)); + auto mC = make_tensor(tensor_C.data().get(), make_layout(shape_mC, stride_mC)); + auto mD_ref = make_tensor(tensor_D_reference.data().get(), make_layout(shape_mC, stride_mC)); + auto mD_computed = make_tensor(tensor_D_computed.data().get(), make_layout(shape_mC, stride_mC)); + auto mBias = make_tensor(tensor_bias.data().get(), 
make_layout(shape_mBias)); + auto mAlpha = make_tensor(tensor_alpha.data().get(), make_layout(shape_mBias)); + auto mBeta = make_tensor(tensor_beta.data().get(), make_layout(shape_mBias)); + + cutlass::reference::host::ConvEpilogueFusionParams< + ElementAccumulator, + ElementScalar, + ElementCompute, + ElementC, + ElementD, + IsResidualEnabled, + decltype(mAlpha), + decltype(mBeta), + decltype(mBias), + ActivationFunctor> + epilogue_fusion_params{}; + + epilogue_fusion_params.alpha = alpha; + epilogue_fusion_params.beta = beta; + + if constexpr (IsPerChannelScaleEnabled) { + epilogue_fusion_params.tensor_alpha = mAlpha; + epilogue_fusion_params.tensor_beta = mBeta; + } + + if constexpr (IsBiasEnabled) { + epilogue_fusion_params.tensor_bias = mBias; + } + + auto padding = cute::reverse(problem_shape.lower_padding); + auto tstride = cute::reverse(problem_shape.traversal_stride); + auto dilation = cute::reverse(problem_shape.dilation); + + cutlass::reference::host::ConvReferenceImpl< + ConvOp, + NumSpatialDimensions, + decltype(mA), + decltype(mB), + decltype(mC), + decltype(mD_ref), + decltype(padding), + decltype(tstride), + decltype(dilation), + decltype(epilogue_fusion_params)> + reference_impl(mA, mB, mC, mD_ref, padding, tstride, dilation, epilogue_fusion_params); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = CreateCachedConvNd3xTestKey< + ProblemShape, + ElementA, + ElementB, + ElementC, + ElementD + >( + ConvOp, + problem_shape, + alpha, + beta, + tensor_A, + tensor_B, + tensor_C + ); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string convnd_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + #if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) + CachedTestResultListing cached_results(convnd_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = 
cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + #endif + + if (!cached_result_loaded) { + // Compute reference + reference_impl.compute_reference(); + + #if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) + cached_test_result.D = TensorHash(tensor_D_reference); + CachedTestResultListing cached_results(convnd_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(convnd_result_cache_name); + #endif + } // if (!cached_result_loaded) + + #if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) + uint32_t tensor_D_computed_hash = TensorHash(tensor_D_computed); + passed = (tensor_D_computed_hash == cached_test_result.D); + // If hash fails, double check against reference implementation. + if(!passed) { + std::cerr << "Hash-based comparison unsuccessful for key:" << "\n" << cached_test_key + << ", comparing with reference implementation now.\n"; + if (cached_result_loaded) { + // Compute reference + reference_impl.compute_reference(); + } + // Validate kernel against reference + passed = compare_reference(mD_ref, mD_computed, mA, mB, mAlpha, mBeta, mBias, this->epsilon); + } + #else + // Validate kernel against reference + passed = compare_reference(mD_ref, mD_computed, mA, mB, mAlpha, mBeta, mBias, this->epsilon); + #endif + + EXPECT_TRUE(passed); + return passed; + } + + template< + class Engine, class Layout, + class EngineA, class LayoutA, + class EngineB, class LayoutB, + class EngineAlpha, class LayoutAlpha, + class EngineBeta, class LayoutBeta, + class EngineBias, class LayoutBias> + static constexpr bool + compare_reference( + cute::Tensor const& reference, + cute::Tensor const& computed, + cute::Tensor const& A, + cute::Tensor const& B, + cute::Tensor const& tensor_alpha, + cute::Tensor const& tensor_beta, + cute::Tensor const& tensor_bias, + float epsilon = 0.0f) { + if (size(reference) != size(computed)) { + return false; + } + + bool passed = true; + if (epsilon == 0.0f) { + // fast refcheck w/o 
epsilon + for (size_t i = 0; i < size_t(size(reference)); ++i) { + if (reference(i) != computed(i)) { + passed = false; + printf("[%llu] %f, %f\n", static_cast(i), + float(reference(i)), float(computed(i))); + break; + } + } + } else { + // refcheck with epsilon + for (size_t i = 0; i < size_t(size(reference)); ++i) { + auto ref = static_cast(reference(i)); + auto act = static_cast(computed(i)); + auto abs_error = std::abs(act - ref); + auto rel_error = abs_error / (std::max(std::abs(act), std::abs(ref)) + 0.00001f); + if (std::isnan(abs_error) || std::isnan(rel_error) || + std::min(abs_error, rel_error) > epsilon) { + passed = false; + printf("[%llu] %f, %f\n", static_cast(i), + float(reference(i)), float(computed(i))); + break; + } + } + } + #if CUTLASS_DEBUG_TRACE_LEVEL > 1 + if (not passed) { + cute::print("Reference:"); + cute::print_tensor(reference); + cute::print("\nComputed:"); + cute::print_tensor(computed); + cute::print("\n"); + + for (size_t i = 0; i < size_t(size(A)); ++i) { + printf("[%llu]: A = %f\n", static_cast(i), float(A(i))); + } + for (size_t i = 0; i < size_t(size(B)); ++i) { + printf("[%llu]: B = %f\n", static_cast(i), float(B(i))); + } + if constexpr (IsPerChannelScaleEnabled) { + for (size_t i = 0; i < size_t(size(tensor_alpha)); ++i) { + printf("[%llu]: alpha = %f\n", static_cast(i), + float(tensor_alpha(i))); + } + for (size_t i = 0; i < size_t(size(tensor_beta)); ++i) { + printf("[%llu]: beta = %f\n", static_cast(i), + float(tensor_beta(i))); + } + } + if constexpr (IsBiasEnabled) { + for (size_t i = 0; i < size_t(size(tensor_bias)); ++i) { + printf("[%llu]: bias = %f\n", static_cast(i), + float(tensor_bias(i))); + } + } + for (size_t i = 0; i < size_t(size(reference)); ++i) { + printf("[%llu]: ref = %f, computed = %f\n", static_cast(i), + float(reference(i)), float(computed(i))); + } + } + #endif + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template 
+bool TestAllConv(double alpha = 1.0, double beta = 0.0, float epsilon = 0.0f, + dim3 cluster_shape = dim3(0, 0, 0), + dim3 cluster_shape_fallback = dim3(0, 0, 0) + ) { + using ElementScalar = typename Conv::EpilogueOutputOp::ElementScalar; + + bool passed = true; + ConvTestbed testbed; + testbed.epsilon = epsilon; + auto problem_vector = get_conv_problem_vector< + Conv::NumSpatialDimensions, Conv::DispatchPolicy::ConvOp, SupportStrides>(); + + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using MaxSwizzleSize = typename gemm::device::detail::MaxSwizzleSize; + using Splits = typename gemm::device::detail::Splits; + + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr (UsesStreamKScheduler) { + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + } + + for (auto conv_problem : problem_vector) { + #if CUTLASS_DEBUG_TRACE_LEVEL > 0 + print(conv_problem); + #endif + for (DecompositionMode decomp_mode : decomposition_modes) { + std::vector problem_splits = {Splits{1}}; + if constexpr (UsesStreamKScheduler) { + if (decomp_mode == DecompositionMode::SplitK) { + problem_splits.push_back(Splits{2}); + problem_splits.push_back(Splits{4}); + } + } + for (auto splits : problem_splits) { + + passed = testbed.run( + conv_problem, + cutlass::from_real(alpha), + cutlass::from_real(beta), + cluster_shape, + cluster_shape_fallback, + RasterOrderOptions::Heuristic, // raster_order + MaxSwizzleSize(1), + splits, + decomp_mode + ); + if (!passed) { + printf("Failed test for "); print(conv_problem); + return false; + } + } // splits + } // decomposition_mode + } + + 
return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace test::conv::device + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ff170be142ff9d0d02cc684c2873c3ec014bd236 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp @@ -0,0 +1,158 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +using namespace cute; + +template +struct SharedStorage +{ + cute::ArrayEngine> smem; +}; + +template +__global__ void +test_tiled_cp_async_device_cute(T const* g_in, T* g_out, + TiledCopy const tiled_copy, + GmemLayout gmem_layout, SmemLayout smem_layout) +{ + using namespace cute; + + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + auto thr_copy = tiled_copy.get_slice(threadIdx.x); + Tensor gA = make_tensor(make_gmem_ptr(g_in), gmem_layout); + Tensor gB = make_tensor(make_gmem_ptr(g_out), gmem_layout); + + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); + + auto tAgA = thr_copy.partition_S(gA); + auto tAsA = thr_copy.partition_D(sA); + +#if 0 + if (thread0()) { + print("gA : "); print(gA.layout()); print("\n"); + print("sA : "); print(sA.layout()); print("\n"); + print("tAgA: "); print(tAgA.layout()); print("\n"); + print("tAsA: "); print(tAsA.layout()); print("\n"); + } +#endif + + copy(tiled_copy, tAgA, tAsA); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + // Store trivially smem -> gmem + + if (thread0()) { + 
copy(sA, gB); + } + +} + +template +void +test_tiled_cp_async( + TiledCopy const tiled_copy, + GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout) +{ + using namespace cute; + + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + for (int i = 0; i < size(hA_in); ++i) { hA_in(i) = static_cast(i % 13); } + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + // Launch + int smem_size = int(sizeof(SharedStorage)); + test_tiled_cp_async_device_cute<<<1, 128, smem_size>>>( + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast (raw_pointer_cast(d_out.data())), + tiled_copy, + gmem_layout, + smem_layout); + + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); + + // Validate the results. Print only the first 3 errors. 
+ int count = 3; + for (int i = 0; i < size(hA_out) && count > 0; ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); + if (hA_in(i) != hA_out(i)) { + --count; + } + } +} + +template +void test_cp_async_no_swizzle() { + using namespace cute; + auto smem_atom = SMEM_LAYOUT{}; + auto smem_layout = tile_to_shape(smem_atom, Shape{}); + auto gmem_layout = make_layout(make_shape(M{}, N{}), GMEM_STRIDE_TYPE{}); + test_tiled_cp_async(TILED_COPY{}, gmem_layout, smem_layout); +} + +template +void test_cp_async_with_swizzle() { + using namespace cute; + auto swizzle_atom = SWIZZLE_ATOM{}; + auto smem_atom = composition(swizzle_atom, SMEM_LAYOUT{}); + auto smem_layout = tile_to_shape(smem_atom, Shape{}); + auto gmem_layout = make_layout(make_shape(M{}, N{}), GMEM_STRIDE_TYPE{}); + test_tiled_cp_async(TILED_COPY{}, gmem_layout, smem_layout); +} diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3ff20d4087ee2fd6f4f74338e3e63eef27c221d3 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp @@ -0,0 +1,775 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/relatively_equal.h" +#include "cutlass_unit_test.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include + +#include +#include + +#include + +using namespace cute; + +template +struct fp64_tester { + using value_type = double; +}; + +template +struct fp64_tester> { + using value_type = complex; +}; + +template // logical shape (M, N) +auto host_generate_gemm_inputs( + ALayout a_layout, + BLayout b_layout, + CLayout c_layout +) { + thrust::host_vector h_a(cosize(a_layout)); + thrust::host_vector h_b(cosize(b_layout)); + thrust::host_vector h_c(cosize(c_layout)); + thrust::host_vector h_c_out(cosize(c_layout)); + + auto h_a_tensor = make_tensor(h_a.data(), a_layout); + auto h_b_tensor = make_tensor(h_b.data(), b_layout); + auto h_c_tensor = make_tensor(h_c.data(), c_layout); + size_t max_size = std::max({static_cast(size(a_layout)), + static_cast(size(b_layout)), + static_cast(size(c_layout))}); + for (size_t i = 0; i < max_size; ++i) { + double di = static_cast(i); + if(i < size(a_layout)) { + h_a_tensor(i) = static_cast(di / size(a_layout)); + } + if(i < size(b_layout)) { + h_b_tensor(i) = static_cast(di / size(a_layout)); + } + if(i < size(c_layout)) { + h_c_tensor(i) = static_cast((di*di) / size(a_layout)); + } + } + + return std::make_tuple(h_a, h_b, h_c, h_c_out); +} + +template +thrust::host_vector +host_reference_gemm(Alpha alpha, + Tensor const& h_a_tensor, + Tensor const& h_b_tensor, + Beta beta, + Tensor const& h_c_tensor, + ALoadTransform const& a_load_transform = {}, + BLoadTransform const& b_load_transform = {}, + CLoadTransform const& c_load_transform = {}, + CStoreTransform const& c_store_transform = {}) + { + // Cannot use ::value_type because it propagates to complex::value_type, + // so ViewEngine>::value_type == double + using TA = remove_cv_t; + using TB = remove_cv_t; + using TC = 
remove_cv_t; + + using tester = fp64_tester; + using ABC_64 = typename tester::value_type; + + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + + thrust::host_vector h_c_ref(cosize(h_c_tensor.layout()), static_cast(0.0)); + auto h_c_ref_tensor = make_tensor(h_c_ref.data(), h_c_tensor.layout()); + // A * B + for (int k = 0; k < size<1>(h_a_tensor); k++) { + for (int m = 0; m < size<0>(h_a_tensor); m++) { + for (int n = 0; n < size<0>(h_b_tensor); n++) { + const auto a_value = a_load_transform(h_a_tensor(m, k)); + const auto b_value = b_load_transform(h_b_tensor(n, k)); + const auto a_value_fp64 = static_cast(a_value); + const auto b_value_fp64 = static_cast(b_value); + h_c_ref_tensor(m, n) += static_cast(a_value_fp64 * b_value_fp64); + } + } + } + // C = A*B + C + for (int i = 0; i < size(h_c_ref_tensor); i++) { + const auto ab_value_fp64 = static_cast(h_c_ref_tensor(i)); + const auto c_value_fp64 = static_cast(c_load_transform(h_c_tensor(i))); + h_c_ref_tensor(i) = c_store_transform(static_cast(alpha * ab_value_fp64 + beta * c_value_fp64)); + } + + return h_c_ref; +} + +template +void verify_gemm_correctness(cute::Tensor const& h_c_out_tensor, + cute::Tensor const& h_c_ref_tensor) +{ + // Cannot use ::value_type because it propagates to complex::value_type, + // so ViewEngine>::value_type == double + using TC = remove_cv_t; + + using tester = fp64_tester; + using ABC_64 = typename tester::value_type; + + for (int i = 0; i < size(h_c_ref_tensor); i++) { + ABC_64 h_c_ref_i = h_c_ref_tensor(i); + ABC_64 h_c_out_i = h_c_out_tensor(i); + double epsilon(0.1f); + double nonzero_floor(std::numeric_limits::min()); + bool passed = cutlass::relatively_equal(h_c_out_i, h_c_ref_i, epsilon, nonzero_floor); + ASSERT_TRUE(passed) << i << " - result:" << h_c_out_i << " expected:" << h_c_ref_i; + } +} + + +template +__launch_bounds__(ThreadBlockSize) __global__ void 
+cooperative_gemm_kernel(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + SMemCLayout smem_c_layout, + TA const* a, + TB const* b, + TC const* c, + TC * c_out, + Alpha const alpha, + Beta const beta, + TiledMma tiled_mma, + ALoadTransform a_load_transform, + BLoadTransform b_load_transform, + CLoadTransform c_load_transform, + CStoreTransform c_store_transform, + SMemCopyOpA a_copy_op, + SMemCopyOpB b_copy_op, + SMemCopyLdOpC c_copy_ld_op, + SMemCopyStOpC c_copy_st_op) +{ + using namespace cute; + + Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), gmem_a_layout); + Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), gmem_b_layout); + Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), gmem_c_layout); + Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), gmem_c_layout); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + extern __shared__ float4 smem_buf[]; + auto* smem_ptr = reinterpret_cast(smem_buf); + auto* smem_ptr_a = smem_ptr; + auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(smem_a_layout)), copy_max_vec_bytes); + auto* smem_ptr_c = smem_ptr_b + round_up((sizeof(TB) * cosize(smem_b_layout)), copy_max_vec_bytes); + + Tensor s_a_tensor = make_tensor(make_smem_ptr(smem_ptr_a), smem_a_layout); + Tensor s_b_tensor = make_tensor(make_smem_ptr(smem_ptr_b), smem_b_layout); + Tensor s_c_tensor = make_tensor(make_smem_ptr(smem_ptr_c), smem_c_layout); + + cooperative_copy(threadIdx.x, g_a_tensor, s_a_tensor); + cooperative_copy(threadIdx.x, g_b_tensor, s_b_tensor); + cooperative_copy(threadIdx.x, g_c_tensor, s_c_tensor); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + cooperative_gemm( + threadIdx.x, tiled_mma, + alpha, s_a_tensor, s_b_tensor, beta, s_c_tensor, + a_load_transform, b_load_transform, c_load_transform, c_store_transform, + a_copy_op, b_copy_op, c_copy_ld_op, c_copy_st_op + ); + __syncthreads(); + + 
cooperative_copy(threadIdx.x, s_c_tensor, g_c_out_tensor); +} + +template +__launch_bounds__(ThreadBlockSize) __global__ void +cooperative_gemm_kernel_rmem_c(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + TA const* a, + TB const* b, + TC const* c, + TC * c_out, + TiledMma tiled_mma, + ALoadTransform a_load_transform, + BLoadTransform b_load_transform, + CLoadTransform c_load_transform, + CStoreTransform c_store_transform, + SMemCopyOpA a_copy_op, + SMemCopyOpB b_copy_op) + { + using namespace cute; + + Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), gmem_a_layout); + Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), gmem_b_layout); + Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), gmem_c_layout); + Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), gmem_c_layout); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + extern __shared__ float4 smem_buf[]; + auto* smem_ptr = reinterpret_cast(smem_buf); + auto* smem_ptr_a = smem_ptr; + auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(smem_a_layout)), copy_max_vec_bytes); + + Tensor s_a_tensor = make_tensor(make_smem_ptr(smem_ptr_a), smem_a_layout); + Tensor s_b_tensor = make_tensor(make_smem_ptr(smem_ptr_b), smem_b_layout); + + cooperative_copy(threadIdx.x, g_a_tensor, s_a_tensor); + cooperative_copy(threadIdx.x, g_b_tensor, s_b_tensor); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + // Create C fragment for storing intermediate results + auto thr_mma = TiledMma().get_thread_slice(threadIdx.x); + Tensor g_c_partition = thr_mma.partition_C(g_c_tensor); + Tensor g_c_out_partition = thr_mma.partition_C(g_c_out_tensor); + Tensor r_c_partition = thr_mma.make_fragment_C(g_c_partition); + + // Create indexing help for predicated GEMMs + Tensor cC = make_identity_tensor(shape(gmem_c_layout)); + Tensor tCcC = thr_mma.partition_C(cC); + + // Load C from global + // (always 
loading in predicated way) + CUTE_UNROLL + for (int i = 0; i < size(r_c_partition); ++i) + { + if (elem_less(tCcC(i), shape(g_c_tensor))) + { + r_c_partition(i) = c_load_transform(g_c_partition(i)); + } + } + + cooperative_gemm( + threadIdx.x, tiled_mma, s_a_tensor, s_b_tensor, r_c_partition, + a_load_transform, b_load_transform, a_copy_op, b_copy_op + ); + + __syncthreads(); + + // Store C to global + // (always storing in predicated way) + CUTE_UNROLL + for (int i = 0; i < size(r_c_partition); ++i) + { + if (elem_less(tCcC(i), shape(g_c_tensor))) + { + g_c_out_partition(i) = c_store_transform(r_c_partition(i)); + } + } +} + +template, + class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment, + class CSMemCopyLdOp = AutoVectorizingCopyWithAssumedAlignment, + class CSMemCopyStOp = AutoVectorizingCopyWithAssumedAlignment> +void test_cooperative_gemm(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + SMemCLayout smem_c_layout, + TiledMma tiled_mma, + ALoadTransform a_load_transform = {}, + BLoadTransform b_load_transform = {}, + CLoadTransform c_load_transform = {}, + CStoreTransform c_store_transform = {}, + ASMemCopyOp a_smem_copy_op = {}, + BSMemCopyOp b_smem_copy_op = {}, + CSMemCopyLdOp c_smem_copy_ld_op = {}, + CSMemCopyStOp c_smem_copy_st_op = {}) +{ + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + + static_assert(size<0>(gmem_a_layout) == size<0>(gmem_c_layout)); // AM == CM + static_assert(size<0>(gmem_b_layout) == size<1>(gmem_c_layout)); // BN == CN + static_assert(size<1>(gmem_a_layout) == size<1>(gmem_b_layout)); // AK == BK + + static_assert(size<0>(smem_a_layout) == size<0>(smem_c_layout)); // AM == CM + static_assert(size<0>(smem_b_layout) == size<1>(smem_c_layout)); // BN == CN + static_assert(size<1>(smem_a_layout) == 
size<1>(smem_b_layout)); // AK == BK + + static_assert(cute::size(gmem_a_layout) == cute::size(smem_a_layout)); + static_assert(cute::size(gmem_b_layout) == cute::size(smem_b_layout)); + static_assert(cute::size(gmem_c_layout) == cute::size(smem_c_layout)); + +#if 0 + print(" "); print("gmem: "); print(gmem_layout); print("\n"); + print(" "); print("smem: "); print(smem_layout); print("\n"); + print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); +#endif + + const auto alpha = static_cast(1.1); + const auto beta = static_cast(1.2); + + // Generate inputs + auto [h_a, h_b, h_c, h_c_out] = host_generate_gemm_inputs(gmem_a_layout, gmem_b_layout, gmem_c_layout); + + thrust::device_vector d_a(h_a); + thrust::device_vector d_b(h_b); + thrust::device_vector d_c(h_c); + thrust::device_vector d_c_out(h_c_out.size(), TC(float(-1))); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) + + round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes) + + sizeof(TC) * h_c.size(); + + + auto kernel = cooperative_gemm_kernel< + ThreadBlockSize, CopyMaxVecBits, + GMemALayout, GMemBLayout, GMemCLayout, + SMemALayout, SMemBLayout, SMemCLayout, + TA, TB, TC, decltype(alpha), decltype(beta), + TiledMma, + ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform, + ASMemCopyOp, BSMemCopyOp, CSMemCopyLdOp, CSMemCopyStOp + >; + + ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_size)), 0); + + kernel<<<1, ThreadBlockSize, shared_memory_size>>>( + gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + thrust::raw_pointer_cast(d_a.data()), + thrust::raw_pointer_cast(d_b.data()), + thrust::raw_pointer_cast(d_c.data()), + thrust::raw_pointer_cast(d_c_out.data()), + alpha, + beta, + tiled_mma, + a_load_transform, + b_load_transform, + c_load_transform, + 
c_store_transform, + a_smem_copy_op, + b_smem_copy_op, + c_smem_copy_ld_op, + c_smem_copy_st_op + ); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + cudaError_t error = cudaGetLastError(); + FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; + } + + // Reference gemm + auto h_c_ref = host_reference_gemm(alpha, + make_tensor(h_a.data(), gmem_a_layout), + make_tensor(h_b.data(), gmem_b_layout), + beta, + make_tensor(h_c.data(), gmem_c_layout), + a_load_transform, + b_load_transform, + c_load_transform, + c_store_transform); + + // Copy result data + h_c_out = d_c_out; + + // Verify correctness + verify_gemm_correctness(make_tensor(h_c_out.data(), gmem_c_layout), + make_tensor(h_c_ref.data(), gmem_c_layout)); +} + +template, + class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment> +void test_cooperative_gemm_rmem_c(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + TiledMma tiled_mma, + ALoadTransform a_load_transform = {}, + BLoadTransform b_load_transform = {}, + CLoadTransform c_load_transform = {}, + CStoreTransform c_store_transform = {}, + ASMemCopyOp a_smem_copy_op = {}, + BSMemCopyOp b_smem_copy_op = {}) +{ + static_assert(size<0>(gmem_a_layout) == size<0>(gmem_c_layout)); // AM == CM + static_assert(size<0>(gmem_b_layout) == size<1>(gmem_c_layout)); // BN == CN + static_assert(size<1>(gmem_a_layout) == size<1>(gmem_b_layout)); // AK == BK + + static_assert(size<1>(smem_a_layout) == size<1>(smem_b_layout)); // AK == BK + + static_assert(cute::size(gmem_a_layout) == cute::size(smem_a_layout)); + static_assert(cute::size(gmem_b_layout) == cute::size(smem_b_layout)); + +#if 0 + print(" "); print("gmem: "); print(gmem_layout); print("\n"); + print(" "); print("smem: "); print(smem_layout); print("\n"); + print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); +#endif + + const auto alpha = 
static_cast(1.0); + const auto beta = static_cast(1.0); + + // Generate inputs + auto [h_a, h_b, h_c, h_c_out] = + host_generate_gemm_inputs(gmem_a_layout, gmem_b_layout, gmem_c_layout); + + thrust::device_vector d_a(h_a); + thrust::device_vector d_b(h_b); + thrust::device_vector d_c(h_c); + thrust::device_vector d_c_out(h_c_out.size(), static_cast(-1)); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) + + round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes); + + + auto kernel = cooperative_gemm_kernel_rmem_c< + ThreadBlockSize, CopyMaxVecBits, + GMemALayout, GMemBLayout, GMemCLayout, + SMemALayout, SMemBLayout, + TA, TB, TC, + TiledMma, + ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform, + ASMemCopyOp, BSMemCopyOp + >; + + ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_size)), 0); + + kernel<<<1, ThreadBlockSize, shared_memory_size>>>( + gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + thrust::raw_pointer_cast(d_a.data()), + thrust::raw_pointer_cast(d_b.data()), + thrust::raw_pointer_cast(d_c.data()), + thrust::raw_pointer_cast(d_c_out.data()), + tiled_mma, + a_load_transform, b_load_transform, c_load_transform, c_store_transform, + a_smem_copy_op, b_smem_copy_op + ); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + cudaError_t error = cudaGetLastError(); + FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; + } + + // Copy result data + h_c_out = d_c_out; + + // Reference gemm + auto h_c_ref = host_reference_gemm(alpha, + make_tensor(h_a.data(), gmem_a_layout), + make_tensor(h_b.data(), gmem_b_layout), + beta, + make_tensor(h_c.data(), gmem_c_layout), + a_load_transform, + b_load_transform, + c_load_transform, + c_store_transform); + + // Verify correctness + 
verify_gemm_correctness(make_tensor(h_c_out.data(), gmem_c_layout), + make_tensor(h_c_ref.data(), gmem_c_layout)); +} + +template +void test_cooperative_gemm_col_major_layout(ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops ... ops) +{ + auto a_layout = make_layout(select<0, 2>(shape_mnk)); + auto b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto c_layout = make_layout(select<0, 1>(shape_mnk)); + + test_cooperative_gemm + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + c_layout, + tiled_mma, + ops...); +} + + +template +std::enable_if_t, + cute::is_layout, + cute::is_layout>> +test_cooperative_gemm_col_major_layout(SMemAtomLayoutA smem_atom_layout_a, + SMemAtomLayoutB smem_atom_layout_b, + SMemAtomLayoutC smem_atom_layout_c, + ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops&& ... ops) +{ + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk)); + + auto smem_a_layout = tile_to_shape( + smem_atom_layout_a, + make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); + + auto smem_b_layout = tile_to_shape( + smem_atom_layout_b, + make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); + + auto smem_c_layout = tile_to_shape( + smem_atom_layout_c, + make_shape(shape<0>(gmem_c_layout), shape<1>(gmem_c_layout))); + + test_cooperative_gemm + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + tiled_mma, + ops...); +} + + +template +void test_cooperative_gemm_col_major_layout_rmem_c(ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops ... 
ops) +{ + auto a_layout = make_layout(select<0, 2>(shape_mnk)); + auto b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto c_layout = make_layout(select<0, 1>(shape_mnk)); + + + test_cooperative_gemm_rmem_c + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + tiled_mma, + ops...); +} + +template +std::enable_if_t, + cute::is_layout>> +test_cooperative_gemm_col_major_layout_rmem_c(SMemAtomLayoutA smem_atom_layout_a, + SMemAtomLayoutB smem_atom_layout_b, + ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops ... ops) +{ + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk)); + + auto smem_a_layout = tile_to_shape( + smem_atom_layout_a, + make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); + + auto smem_b_layout = tile_to_shape( + smem_atom_layout_b, + make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); + + test_cooperative_gemm_rmem_c + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + tiled_mma, + ops...); +} + +template +void test_cooperative_gemm_col_major_layout_rmem_c(Args&& ... args) +{ + test_cooperative_gemm_col_major_layout_rmem_c, + T, T, T> + (static_cast(args)...); +} + +template +void test_cooperative_gemm_col_major_layout(Args&& ... 
args) +{ + test_cooperative_gemm_col_major_layout, + T, T, T> + (static_cast(args)...); +} diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4d2620e62ff247e36ae49809ab4ef3416560ae31 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp @@ -0,0 +1,217 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass_unit_test.h" + +#include +#include + +#include +#include + +#include + +namespace cutlass::test { + +template +struct SharedStorage +{ + cute::ArrayEngine> smem; + alignas(16) cute::uint64_t tma_load_mbar[1]; +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + +template +__global__ void +tma_test_device_cute(T const* g_in, T* g_out, + CUTE_GRID_CONSTANT TiledCopy const tma, CTA_Tiler cta_tiler, + GmemLayout gmem_layout, SmemLayout smem_layout) +{ + using namespace cute; + CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); + + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) 
+ // Shared memory barriers use 64bits in SMEM for synchronization + uint64_t* tma_load_mbar = shared_storage.tma_load_mbar; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA = tma.get_tma_tensor(shape(gmem_layout)); + Tensor mB = make_tensor(make_gmem_ptr(g_out), gmem_layout); + + constexpr int R = rank_v; + Tensor gA = flat_divide(mA, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + Tensor gB = flat_divide(mB, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + + // + // Prepare the TMA_LOAD + // + + auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice + Tensor tAgA_x = cta_tma.partition_S(gA); // (TMA,TMA_M,TMA_N,REST_M,REST_N) + Tensor tAsA_x = cta_tma.partition_D(sA); // (TMA,TMA_M,TMA_N) + +#if 0 + if (thread0()) { + print(tma); + print("TILE : "); print(cta_tiler); print("\n"); + print(" mA : "); print( mA); print("\n"); + print(" mB : "); print( mB); print("\n"); + print(" gA : "); print( gA); print("\n"); + print(" gB : "); print( gB); print("\n"); + print(" sA : "); print( sA); print("\n"); + print("tAgA_x: "); print(tAgA_x); print("\n"); + print("tAsA_x: "); print(tAsA_x); print("\n"); + } +#endif + + // + // Perform the TMA_LOAD + // + + // INPUT: Group the REST_X modes and the TMA_X modes to easily iterate through the tiles + Tensor tAgA = group_modes<1,rank(tAgA_x)>(tAgA_x); // (TMA,REST) + Tensor tAsA = group_modes<1,rank(tAsA_x)>(tAsA_x); // (TMA,REST) + static_assert(size<1>(tAsA) == 1); + + // OUTPUT: Group the CTA_TILE_X modes and REST_X modes for output + Tensor tBgB = group_modes<0,R>(group_modes(gB)); // (CTA_TILE, REST) + +#if 0 + if (thread0()) { + print("tAgA : "); print(tAgA); print("\n"); + print("tAsA : "); print(tAsA); print("\n"); + print("tBgB : "); print(tBgB); print("\n"); + } +#endif + + // Test L2 prefetch + if (threadIdx.x == 0) { + prefetch(tma, tAgA); + } + + // Loop over the TMA stages, using smem as our 
buffer + for (int stage = 0; stage < size<1>(tAgA); ++stage) + { + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + constexpr int kTmaTransactionBytes = sizeof(make_tensor_like(tensor<0>(tAsA))); + + if (threadIdx.x == 0) + { + /// Initialize shared memory barrier + tma_load_mbar[0] = 0; + cute::initialize_barrier(tma_load_mbar[0], 1 /*numThreads*/); + cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes); + + copy(tma.with(tma_load_mbar[0]), tAgA(_,stage), tAsA(_,0)); + } + __syncthreads(); + + /// Wait on the shared memory barrier until the phase bit flips from kPhaseBit value + constexpr int kPhaseBit = 0; + cute::wait_barrier(tma_load_mbar[0], kPhaseBit); + + // + // Write out trivially smem -> gmem + // + + // Subbyte elements could cause race conditions, so be even more conservative + if (thread0()) { + copy(sA, tBgB(_,stage)); + } + + __syncthreads(); + } +} + +template +auto +test_tma_load(CopyOp const& copy_op, + GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout, + CTA_Tile const& cta_tile) +{ + using namespace cute; + + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + for (size_t i = 0; i < h_in.size(); ++i) { + h_in[i] = uint8_t(i % 13); + } + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), uint8_t(-1)); // overflow uint + + // Create TMA for this device Tensor + Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_in.data())), gmem_layout); + auto tma = make_tma_copy(copy_op, gA, smem_layout, cta_tile, Int<1>{}); + //print(tma); + + // Launch + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast (raw_pointer_cast(d_out.data())), + tma, 
cta_tile, + gmem_layout, + smem_layout); + + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); + + // Validate the results. Print only the first 3 errors. + int count = 3; + for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); + if (hA_in(i) != hA_out(i)) { + --count; + } + } + + return tma; +} + +#endif + +} // end namespace cutlass::test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_mcast_load_testbed.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_mcast_load_testbed.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3e0ec46df1b672c35c3c38f731c09b0134d4cd80 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_mcast_load_testbed.hpp @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass_unit_test.h" + +#include +#include + +#include +#include + +#include +#include +#include + +namespace cutlass::test { + +template +struct SharedStorage +{ + cute::ArrayEngine> smem; + alignas(16) cute::uint64_t tma_load_mbar[1]; +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + +template +__global__ void +tma_test_device_cute(T const* g_in, T* g_out, GmemLayout gmem_layout, SmemLayout smem_layout, + CUTE_GRID_CONSTANT CopyAtom const tma, CTA_Tiler cta_tiler, Cluster_Size cluster_size) +{ + using namespace cute; + CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); + + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) 
+ // Shared memory barriers use 64bits in SMEM for synchronization + uint64_t* tma_load_mbar = shared_storage.tma_load_mbar; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA = tma.get_tma_tensor(shape(gmem_layout)); + Tensor mB = make_tensor(make_gmem_ptr(g_out), gmem_layout); + + Tensor gA = zipped_divide(mA, cta_tiler); // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) + Tensor gB = zipped_divide(mB, cta_tiler); // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) + +#if 1 + if (thread0()) { + print(tma); + print("TILE : "); print(cta_tiler); print("\n"); + print(" mA : "); print( mA); print("\n"); + print(" mB : "); print( mB); print("\n"); + print(" gA : "); print( gA); print("\n"); + print(" gB : "); print( gB); print("\n"); + print(" sA : "); print( sA); print("\n"); + } __syncthreads(); cute::cluster_sync(); +#endif + + // + // Prepare the TMA_LOAD + // + + Tensor sA_x = make_tensor(sA.data(), make_layout(sA.layout(), Layout<_1>{})); // ((CTA_TILE_M,CTA_TILE_N,...),_1) + Tensor tBgB = gB; // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) + + int cta_rank_in_cluster = cute::block_rank_in_cluster(); + auto [tAgA, tAsA] = tma_partition(tma, cta_rank_in_cluster, make_layout(cluster_size), sA_x, gA); + +#if 1 + if (thread0()) { + print("sA_x : "); print(sA_x); print("\n"); + print("tBgB : "); print(tBgB); print("\n"); + print("tAgA : "); print(tAgA); print("\n"); + print("tAsA : "); print(tAsA); print("\n"); + } __syncthreads(); cute::cluster_sync(); +#endif + + // + // TMA Multicast Masks -- Get a mask of the active ctas in each TMA + // + + + int elected_cta_rank = 0; + bool elect_one_cta = (elected_cta_rank == cta_rank_in_cluster); + bool elect_one_thr = cute::elect_one_sync(); + + uint16_t tma_mcast_mask = ((uint16_t(1) << cluster_size) - 1); + +#if 1 + if (thread0()) { + print("tma_mcast_mask : "); print(tma_mcast_mask); print("\n"); + } 
__syncthreads(); cute::cluster_sync(); +#endif + + // + // Perform the TMA_LOAD + // + + if (elect_one_thr) { + // Initialize TMA barrier + cute::initialize_barrier(tma_load_mbar[0], /* num_threads */ 1); + } + int tma_phase_bit = 0; + // Ensures all CTAs in the Cluster have initialized + __syncthreads(); + cute::cluster_sync(); + + // Loop over the TMA stages, using smem as our buffer + for (int stage = 0; stage < size<1>(tAgA); ++stage) + { + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + constexpr int kTmaTransactionBytes = sizeof(ArrayEngine); + + if (elect_one_thr) + { + cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes); + + copy(tma.with(tma_load_mbar[0], tma_mcast_mask), tAgA(_,stage), tAsA(_,0)); + } + __syncthreads(); + + /// Wait on the shared memory barrier until the phase bit flips from tma_phase_bit value + cute::wait_barrier(tma_load_mbar[0], tma_phase_bit); + tma_phase_bit ^= 1; + + // + // Write out trivially smem -> gmem + // + + // Subbyte elements could cause race conditions, so be even more conservative + if (elect_one_cta && elect_one_thr) { + copy(sA, tBgB(_,stage)); + } + + __syncthreads(); + cute::cluster_sync(); + } +} + +template +auto +test_tma_load(CopyOp const& copy_op, + GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + using namespace cute; + + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + for (size_t i = 0; i < h_in.size(); ++i) { + h_in[i] = uint8_t(i % 13); + } + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), uint8_t(-1)); // overflow uint + + // Create TMA for this device Tensor + Tensor gA = 
make_tensor(make_gmem_ptr(raw_pointer_cast(d_in.data())), gmem_layout); + auto tma = make_tma_atom(copy_op, gA, smem_layout, cta_tiler, cluster_size); + //print(tma); + + // Launch + + dim3 dimBlock(32); + dim3 dimCluster(size(cluster_size)); + dim3 dimGrid = dimCluster; + int smem_size = sizeof(SharedStorage); + + void* kernel_ptr = (void*) &tma_test_device_cute; + + cutlass::launch_kernel_on_cluster({dimGrid, dimBlock, dimCluster, smem_size}, + kernel_ptr, + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast(raw_pointer_cast(d_out.data())), + gmem_layout, + smem_layout, + tma, cta_tiler, cluster_size); + + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); + + // Validate the results. Print only the first 3 errors. + int count = 3; + for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); + if (hA_in(i) != hA_out(i)) { + --count; + } + } + + return tma; +} + +#endif + +} // end namespace cutlass::test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_store_testbed.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_store_testbed.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0429d2435fbf43c690f311c1f7c04f7025a2dd94 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_store_testbed.hpp @@ -0,0 +1,201 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass_unit_test.h" + +#include +#include + +#include +#include + +#include + +namespace cutlass::test { + +template +struct SharedStorage +{ + cute::ArrayEngine> smem; +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + +template +__global__ void +tma_test_device_cute(T const* g_in, T* g_out, + CUTE_GRID_CONSTANT TiledCopy const tma, CTA_Tiler cta_tiler, + GmemLayout gmem_layout, SmemLayout smem_layout) +{ + using namespace cute; + CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); + + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Construct SMEM tensor + Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA = make_tensor(make_gmem_ptr(g_in), gmem_layout); + Tensor mB = tma.get_tma_tensor(shape(gmem_layout)); + + constexpr int R = rank_v; + Tensor gA = flat_divide(mA, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + Tensor gB = flat_divide(mB, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) 
+ + // + // Prepare the TMA_STORE + // + + auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice + Tensor tBsB_x = cta_tma.partition_S(sB); // (TMA,TMA_M,TMA_N) + Tensor tBgB_x = cta_tma.partition_D(gB); // (TMA,TMA_M,TMA_N,REST_M,REST_N) + +#if 0 + if (thread0()) { + print(tma); + print("TILE : "); print(cta_tiler); print("\n"); + print(" mB : "); print( mB.data()); print(" o "); print( mB.layout()); print("\n"); + print(" gB : "); print( gB.data()); print(" o "); print( gB.layout()); print("\n"); + print("tBgB_x: "); print(tBgB_x.data()); print(" o "); print(tBgB_x.layout()); print("\n"); + print(" sB : "); print( sB.data()); print(" o "); print( sB.layout()); print("\n"); + print("tBsB_x: "); print(tBsB_x.data()); print(" o "); print(tBsB_x.layout()); print("\n"); + } +#endif + + // + // Perform the TMA_STORE + // + + // INPUT: Group the CTA_TILE_X modes and REST_X modes for input + Tensor tAgA = group_modes<0,R>(group_modes(gA)); // (CTA_TILE, REST) + + // OUTPUT: Group the REST_X modes and the TMA_X modes to easily iterate through the tiles + Tensor tBgB = group_modes<1,rank(tBgB_x)>(tBgB_x); // (TMA,REST) + Tensor tBsB = group_modes<1,rank(tBsB_x)>(tBsB_x); // (TMA,REST) + static_assert(size<1>(tBsB) == 1); + +#if 0 + if (thread0()) { + print("tAgA : "); print(tAgA.data()); print(" o "); print(tAgA.layout()); print("\n"); + print("tBsB : "); print(tBsB.data()); print(" o "); print(tBsB.layout()); print("\n"); + print("tBgB : "); print(tBgB.data()); print(" o "); print(tBgB.layout()); print("\n"); + } +#endif + + // Test L2 prefetch + cooperative_prefetch<128>(threadIdx.x, gA); + + // Loop over the TMA stages, using smem as our buffer + for (int stage = 0; stage < size<1>(tBgB); ++stage) + { + // + // Read in trivially gmem -> smem + // + // Subbyte elements could cause race conditions, so be even more conservative + if (thread0()) { + copy(tAgA(_,stage), sB); + } + + __syncthreads(); + cute::cp_async_wait<0>(); + + // + // Perform the TMA_STORE + // + + if 
(threadIdx.x == 0) { + copy(tma, tBsB(_,0), tBgB(_,stage)); + } + + tma_store_wait<0>(); + __syncthreads(); + } +} + +template +void +test_tma_store(CopyOp const& copy_op, + GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout, + CTA_Tile const& cta_tile) +{ + using namespace cute; + + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + for (size_t i = 0; i < h_in.size(); ++i) { + h_in[i] = uint8_t(i % 13); + } + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), uint8_t(-1)); // overflow uint + + // Create TMA for this device Tensor + Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_out.data())), gmem_layout); + auto tma = make_tma_copy(copy_op, gA, smem_layout, cta_tile, Int<1>{}); + //print(tma); + + // Launch + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast (raw_pointer_cast(d_out.data())), + tma, cta_tile, + gmem_layout, + smem_layout); + + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); + + // Validate the results. Print only the first 3 errors. 
+ int count = 3; + for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); + if (hA_in(i) != hA_out(i)) { + --count; + } + } +} + +#endif + +} // end namespace cutlass::test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..3163a0d0eaa24513ee210bd2b310d1bf233773a9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h @@ -0,0 +1,417 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + + \brief Unit tests for epilogues +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" +#include "cutlass/complex.h" + +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace kernel { + +template +__global__ void epilogue_with_reduction_threadblock( + typename Epilogue::ElementVector *ptr_Reduction, + typename Epilogue::OutputTileIterator::Params params_D, + typename Epilogue::OutputTileIterator::Element *ptr_D, + typename Epilogue::OutputTileIterator::Params params_C, + typename Epilogue::OutputTileIterator::Element *ptr_C, + typename Epilogue::TensorTileIterator::Params params_Tensor, + typename Epilogue::TensorTileIterator::Element *ptr_Tensor, + typename Epilogue::OutputOp::Params params_output_op, + cutlass::MatrixCoord problem_size, + cutlass::TensorRef< + typename Epilogue::WarpMmaOperator::ElementC, + typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, + int epilogue_count = 1) { + + __shared__ typename Epilogue::SharedStorage shared_storage; + + 
int thread_idx = threadIdx.x; + int warp_idx = threadIdx.x / 32; + int lane_idx = threadIdx.x % 32; + + // + // Construct the epilogue + // + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_D( + params_D, + ptr_D, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_C( + params_C, + ptr_C, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::TensorTileIterator iterator_T( + params_Tensor, + ptr_Tensor, + problem_size, + thread_idx + ); + + // Epilogue operator + Epilogue epilogue( + shared_storage, + thread_idx, + warp_idx, + lane_idx); + + // + // Initialize the accumulators + // + + int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); + int warp_m = warp_mn % Epilogue::WarpCount::kM; + int warp_n = warp_mn / Epilogue::WarpCount::kM; + + accumulator_ref.add_coord_offset({ + warp_m * Epilogue::WarpMmaOperator::Shape::kM, + warp_n * Epilogue::WarpMmaOperator::Shape::kN}); + + typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); + + typename Epilogue::AccumulatorTile accumulators; + + accumulators.clear(); + accumulator_iterator.load(accumulators); + +#if 0 + // For debugging, enable this block of code to fill each accumulator element with its + // source thread ID. 
+ CUTLASS_PRAGMA_UNROLL + for (size_t i = 0; i < accumulators.size(); ++i) { + typename Epilogue::WarpMmaOperator::ElementC x(threadIdx.x); + accumulators[i] = x; + } + + __syncthreads(); + +#endif + + // + // Perform the epilogue operation + // + + typename Epilogue::OutputOp output_op(params_output_op); + + // Place the epilogue in a loop + for (int iter = 0; iter < epilogue_count; ++iter) { + epilogue(output_op, ptr_Reduction, iterator_D, accumulators, iterator_C, iterator_T); + } +} + +} // namespace kernel +} // namespace test + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Epilogue_ +> +class EpilogueWithReductionTestbed { +public: + + using Epilogue = Epilogue_; + using ElementAccumulator = typename Epilogue::ElementAccumulator; + using ElementCompute = typename Epilogue::OutputOp::ElementCompute; + using ElementTensor = typename Epilogue::TensorTileIterator::Element; + using ElementOutput = typename Epilogue::ElementOutput; + using OutputOpParams = typename Epilogue::OutputOp::Params; + +public: + + // + // Data members + // + + cutlass::MatrixCoord quantized_size; + cutlass::HostTensor accumulator_tensor; + cutlass::HostTensor source_tensor; + cutlass::HostTensor output_tensor; + cutlass::HostTensor additional_tensor; + cutlass::HostTensor reduction_tensor; + + +public: + + // + // Methods + // + + EpilogueWithReductionTestbed(): + quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), + accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + additional_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + reduction_tensor({1, Epilogue::Shape::kN}) { + + // + // Initialize problem space + // + + uint64_t seed = 2019; + + cutlass::reference::host::TensorFillRandomUniform( + accumulator_tensor.host_view(), + seed, + 20, + -20, + 0); + + 
cutlass::reference::host::TensorFillRandomUniform( + source_tensor.host_view(), + seed + 2018, + 20, + -20, + 0); + + cutlass::reference::host::TensorFill(additional_tensor.host_view(), ElementTensor(1)); + } + + bool run_all() { + + /* + double alpha_values[] = {1, 0, 2.25}; + double beta_values[] = {0, 1, -1.25}; + + // Test runtime explodes if we tried to test every case exhaustively. This tests the full + // output tile and several smaller sizes to stress predication. + for (int m_idx = 0; m_idx < 3; ++m_idx) { + for (int n_idx = 0; n_idx < 3; ++n_idx) { + + int m = quantized_size.row() - m_idx * 3; + int n = quantized_size.column() - n_idx * Epilogue::kElementsPerAccess; + + for (double const &alpha : alpha_values) { + for (double const &beta : beta_values) { + + bool passed = run({m, n}, {cutlass::from_real(alpha), cutlass::from_real(beta)}); + + if (!passed) { + return false; + } + } + } + } + } + return true; + */ + + double alpha = 1; + double beta = 0; + + return run( + {quantized_size.row(), quantized_size.column()}, + {cutlass::from_real(alpha), cutlass::from_real(beta)}); + } + + /// Runs the test + bool run( + cutlass::MatrixCoord problem_size, + OutputOpParams output_params) { + + // + // Initialize problem space + // + + ElementOutput default_output = ElementOutput(-127); + ElementAccumulator default_reduction = ElementAccumulator(); + + cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); + cutlass::reference::host::TensorFill(reduction_tensor.host_view(), default_reduction); + + accumulator_tensor.sync_device(); + output_tensor.sync_device(); + source_tensor.sync_device(); + additional_tensor.sync_device(); + reduction_tensor.sync_device(); + + // + // Initialize epilogue parameters + // + + typename Epilogue::OutputTileIterator::Params params_D(output_tensor.device_ref().layout()); + typename Epilogue::OutputTileIterator::Params params_C(source_tensor.device_ref().layout()); + typename 
Epilogue::TensorTileIterator::Params params_T(additional_tensor.device_ref().layout()); + + // + // Launch kernel + // + + dim3 grid(1, 1); + dim3 block(Epilogue::WarpCount::kCount * 32, 1); + + test::kernel::epilogue_with_reduction_threadblock<<< grid, block >>>( + reduction_tensor.device_data(), + params_D, + output_tensor.device_data(), + params_C, + source_tensor.device_data(), + params_T, + additional_tensor.device_data(), + output_params, + problem_size, + accumulator_tensor.device_view()); + + cudaError_t result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; + return false; + } + + // + // Verify results + // + output_tensor.sync_host(); + reduction_tensor.sync_host(); + + int errors = 0; + int const kMaxErrors = 5; + + // + // The output has two parts: + // - GEMM tensor epilogue in canonical layout + // - partial reduction in canonical row-major layout + // + + // Verify the GEMM tensor output + for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { + for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { + + cutlass::MatrixCoord coord{r, c}; + ElementOutput got = output_tensor.at(coord); + + ElementOutput expected; + if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { + + expected = ElementOutput(output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + + output_params.beta * ElementCompute(source_tensor.at(coord))); + } + else { + expected = default_output; + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - output element (" << coord << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) << std::endl; + + ++errors; + } + } + } + + // Verify the partial reduction + for (int c = 0; c < quantized_size.column(); ++c) { + + ElementAccumulator reduction_acc = ElementAccumulator(); + + for (int r = 0; r 
< quantized_size.row(); ++r) { + reduction_acc += accumulator_tensor.at({r, c}); + } + + ElementAccumulator expected = default_reduction; + ElementAccumulator got = reduction_tensor.at({0, c}); + + if (c < problem_size.column()) { + expected = reduction_acc; + } + else { + expected = default_reduction; + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - reduction element (" << c << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) << std::endl; + } + } + + // + // Report results on error + // + + if (errors) { + std::stringstream ss; + ss + << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" + << Epilogue::WarpTileIterator::WarpShape::kM << "x" + << Epilogue::WarpTileIterator::WarpShape::kN + << "_slice_" << Epilogue::WarpCount::kK << ".csv"; + + std::ofstream output_file(ss.str()); + output_file << output_tensor.host_view(); + } + + return !errors; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..e2457fdb4817e1dfb3af73149ae1e4c4458670a2 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed.h @@ -0,0 +1,356 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Unit tests for epilogues +*/ +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/platform/platform.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace kernel { + +template +__global__ void epilogue_threadblock( + typename Epilogue::OutputTileIterator::Params params_D, + typename Epilogue::OutputTileIterator::Element *ptr_D, + typename Epilogue::OutputTileIterator::Params params_C, + typename Epilogue::OutputTileIterator::Element *ptr_C, + typename Epilogue::OutputOp::Params params_output_op, + cutlass::MatrixCoord problem_size, + cutlass::TensorRef< + typename Epilogue::WarpMmaOperator::ElementC, + typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, + int epilogue_count = 1) { + + __shared__ typename Epilogue::SharedStorage shared_storage; + + int thread_idx = threadIdx.x; + int warp_idx = threadIdx.x / 32; + int lane_idx = threadIdx.x % 32; + + // + // Construct the epilogue + // + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_D( + params_D, + ptr_D, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_C( + params_C, + ptr_C, + problem_size, + thread_idx + ); + + // Epilogue operator + Epilogue epilogue( + shared_storage, + thread_idx, + warp_idx, + lane_idx); + + // + // Initialize the accumulators + // + + int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); + int warp_m = warp_mn % Epilogue::WarpCount::kM; + int warp_n = warp_mn / 
Epilogue::WarpCount::kM; + + accumulator_ref.add_coord_offset({ + warp_m * Epilogue::WarpMmaOperator::Shape::kM, + warp_n * Epilogue::WarpMmaOperator::Shape::kN}); + + typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); + + typename Epilogue::AccumulatorTile accumulators; + + accumulators.clear(); + accumulator_iterator.load(accumulators); + +#if 0 + // For debugging, enable this block of code to fill each accumulator element with its + // source thread ID. + CUTLASS_PRAGMA_UNROLL + for (size_t i = 0; i < accumulators.size(); ++i) { + typename Epilogue::WarpMmaOperator::ElementC x(threadIdx.x); + accumulators[i] = x; + } + + __syncthreads(); + +#endif + + // + // Perform the epilogue operation + // + + typename Epilogue::OutputOp output_op(params_output_op); + + // Place the epilogue in a loop + for (int iter = 0; iter < epilogue_count; ++iter) { + epilogue(output_op, iterator_D, accumulators, iterator_C); + } +} + +} // namespace kernel +} // namespace test + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Epilogue_ +> +class EpilogueTestbed { +public: + + using Epilogue = Epilogue_; + using ElementAccumulator = typename Epilogue::ElementAccumulator; + using ElementCompute = typename Epilogue::OutputOp::ElementCompute; + using ElementOutput = typename Epilogue::ElementOutput; + using OutputOpParams = typename Epilogue::OutputOp::Params; + +public: + + // + // Data members + // + + cutlass::MatrixCoord quantized_size; + cutlass::HostTensor accumulator_tensor; + cutlass::HostTensor source_tensor; + cutlass::HostTensor output_tensor; + +public: + + // + // Methods + // + + EpilogueTestbed(): + quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), + accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}) { + + // + // 
Initialize problem space + // + + uint64_t seed = 2019; + + cutlass::reference::host::TensorFillRandomUniform( + accumulator_tensor.host_view(), + seed, + 2, + -2, + 0); + + cutlass::reference::host::TensorFillRandomUniform( + source_tensor.host_view(), + seed + 2018, + 2, + -2, + 0); + } + + bool run_all() { + + double alpha_values[] = {1, 0, 2.25}; + double beta_values[] = {0, 1, -1.25}; + + // Test runtime explodes if we tried to test every case exhaustively. This tests the full + // output tile and several smaller sizes to stress predication. + for (int m_idx = 0; m_idx < 3; ++m_idx) { + for (int n_idx = 0; n_idx < 3; ++n_idx) { + + int m = quantized_size.row() - m_idx * 3; + int n = quantized_size.column() - n_idx * Epilogue::kElementsPerAccess; + + for (double const &alpha : alpha_values) { + for (double const &beta : beta_values) { + + bool passed = run({m, n}, {cutlass::from_real(alpha), cutlass::from_real(beta)}); + + if (!passed) { + return false; + } + } + } + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::MatrixCoord problem_size, + OutputOpParams output_params) { + + // + // Initialize problem space + // + + ElementOutput default_output = ElementOutput(-127); + cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); + + accumulator_tensor.sync_device(); + output_tensor.sync_device(); + source_tensor.sync_device(); + + // + // Initialize epilogue parameters + // + + typename Epilogue::OutputTileIterator::Params params_D(output_tensor.device_ref().layout()); + typename Epilogue::OutputTileIterator::Params params_C(source_tensor.device_ref().layout()); + + // + // Launch kernel + // + + dim3 grid(1, 1); + dim3 block(Epilogue::WarpCount::kCount * 32, 1); + + test::kernel::epilogue_threadblock<<< grid, block >>>( + params_D, + output_tensor.device_data(), + params_C, + source_tensor.device_data(), + output_params, + problem_size, + accumulator_tensor.device_view()); + + cudaError_t result = 
cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; + return false; + } + + // + // Verify results + // + output_tensor.sync_host(); + + int errors = 0; + int const kMaxErrors = 5; + + for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { + for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { + + cutlass::MatrixCoord coord{r, c}; + ElementOutput got = output_tensor.at(coord); + + ElementOutput expected; + if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { + ElementCompute intermediate = + output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + + output_params.beta * ElementCompute(source_tensor.at(coord)); + + if ((cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || std::numeric_limits::is_integer) + && !std::numeric_limits::is_integer) { + std::fesetround(FE_TONEAREST); + expected = ElementOutput(std::nearbyint(float(cutlass::real(intermediate)))); + } else { + expected = ElementOutput(intermediate); + } + } else { + expected = default_output; + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - output element (" << coord << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) + << ", accum: " << (accumulator_tensor.at(coord)) + << ", source: " << OutputIO(source_tensor.at(coord)) + << ", alpha: " << (output_params.alpha) + << ", beta: " << (output_params.beta) << "\n"; + + ++errors; + } + } + } + + // + // Report results on error + // + + if (errors) { + std::stringstream ss; + ss + << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" + << Epilogue::WarpTileIterator::WarpShape::kM << "x" + << Epilogue::WarpTileIterator::WarpShape::kN + << "_slice_" << Epilogue::WarpCount::kK << ".csv"; + + std::ofstream output_file(ss.str()); + output_file << 
output_tensor.host_view(); + } + + return !errors; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..a76578f7638ac1d30161a9bcb55ecec70b5c43e0 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h @@ -0,0 +1,394 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for epilogues +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" +#include "cutlass/complex.h" + +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" + +#include "cutlass/util/host_tensor_planar_complex.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace kernel { + +template +__global__ void epilogue_planar_complex_threadblock( + typename Epilogue::OutputTileIterator::Params params_D, + typename Epilogue::OutputTileIterator::Element *ptr_D, + int64_t imaginary_stride_D, + typename Epilogue::OutputTileIterator::Params params_C, + typename Epilogue::OutputTileIterator::Element *ptr_C, + int64_t imaginary_stride_C, + typename Epilogue::OutputOp::Params params_output_op, + cutlass::MatrixCoord problem_size, + cutlass::TensorRef< + typename Epilogue::WarpMmaOperator::ElementC, + typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, + int64_t imaginary_stride_accum, + int epilogue_count = 1) { + + __shared__ typename Epilogue::SharedStorage shared_storage; + + int thread_idx = threadIdx.x; + int warp_idx = 
threadIdx.x / 32; + int lane_idx = threadIdx.x % 32; + + // + // Construct the epilogue + // + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_D_real( + params_D, + ptr_D, + problem_size, + thread_idx + ); + + typename Epilogue::OutputTileIterator iterator_D_imag( + params_D, + ptr_D + imaginary_stride_D, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_C_real( + params_C, + ptr_C, + problem_size, + thread_idx + ); + + typename Epilogue::OutputTileIterator iterator_C_imag( + params_C, + ptr_C + imaginary_stride_C, + problem_size, + thread_idx + ); + + // Epilogue operator + Epilogue epilogue( + shared_storage, + thread_idx, + warp_idx, + lane_idx); + + // + // Initialize the accumulators + // + + int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); + int warp_m = warp_mn % Epilogue::WarpCount::kM; + int warp_n = warp_mn / Epilogue::WarpCount::kM; + + accumulator_ref.add_coord_offset({ + warp_m * Epilogue::WarpMmaOperator::Shape::kM, + warp_n * Epilogue::WarpMmaOperator::Shape::kN}); + + // + // Load accumulators + // + + typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); + + typename Epilogue::AccumulatorTile accumulators; + + accumulators.clear(); + + accumulator_iterator.load(accumulators.real); + accumulator_iterator.load_with_pointer_offset(accumulators.imag, imaginary_stride_accum); + + // + // Perform the epilogue operation + // + + typename Epilogue::OutputOp output_op(params_output_op); + + // Place the epilogue in a loop so assembly is clearly visible + for (int iter = 0; iter < epilogue_count; ++iter) { + epilogue( + output_op, + iterator_D_real, + iterator_D_imag, + accumulators, + iterator_C_real, + iterator_C_imag); + } +} + +} // namespace kernel +} // namespace test + + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Epilogue_ +> +class EpiloguePlanarComplexTestbed { +public: + + using Epilogue = Epilogue_; + using ElementAccumulator = typename Epilogue::ElementAccumulator; + using ElementCompute = typename Epilogue::OutputOp::ElementCompute; + using ElementOutput = typename Epilogue::ElementOutput; + using OutputOpParams = typename Epilogue::OutputOp::Params; + + using ComplexElementOutput = cutlass::complex; + using ComplexElementAccumulator = cutlass::complex; + using ComplexElementCompute = cutlass::complex; + +public: + + // + // Data members + // + + cutlass::MatrixCoord quantized_size; + cutlass::HostTensorPlanarComplex accumulator_tensor; + cutlass::HostTensorPlanarComplex source_tensor; + cutlass::HostTensorPlanarComplex output_tensor; + +public: + + // + // Methods + // + + EpiloguePlanarComplexTestbed(): + quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), + accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}) { + + // + // Initialize problem space + // + + #if 1 + uint64_t seed = 2019; + + cutlass::reference::host::TensorFillRandomUniform( + accumulator_tensor.host_view(), + seed, + 20, + -20, + 0); + + cutlass::reference::host::TensorFillRandomUniform( + source_tensor.host_view(), + seed + 2018, + 20, + -20, + 0); + #else + + cutlass::reference::host::BlockFillSequential(accumulator_tensor.host_data(), accumulator_tensor.capacity()); + + #endif + } + + bool run_all() { + + cutlass::complex alpha_values[3]; + + alpha_values[0] = cutlass::complex(1, 0); + alpha_values[1] = cutlass::complex(0, 0); + alpha_values[2] = cutlass::complex(2.25f, -0.5f); + + cutlass::complex beta_values[3]; + + beta_values[0] = cutlass::complex(0, 0); + beta_values[1] = cutlass::complex(1, 0); + beta_values[2] = 
cutlass::complex(0.5f, -2.25f); + + // Test runtime explodes if we tried to test every case exhaustively. This tests the full + // output tile and several smaller sizes to stress predication. + for (int m_idx = 0; m_idx < 3; ++m_idx) { + for (int n_idx = 0; n_idx < 3; ++n_idx) { + + cutlass::MatrixCoord problem_size( + quantized_size.row() - m_idx * 3, + quantized_size.column() - n_idx * Epilogue::kElementsPerAccess + ); + + for (auto const &alpha : alpha_values) { + for (auto const &beta : beta_values) { + + bool passed = run(problem_size, {alpha, beta}); + + if (!passed) { + return false; + } + } + } + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::MatrixCoord problem_size, + OutputOpParams output_params) { + + // + // Initialize problem space + // + + ComplexElementOutput default_output = ComplexElementOutput(ElementOutput(-127), ElementOutput(-101)); + + cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); + + accumulator_tensor.sync_device(); + output_tensor.sync_device(); + source_tensor.sync_device(); + + // + // Initialize epilogue parameters + // + + typename Epilogue::OutputTileIterator::Params params_D(output_tensor.layout()); + typename Epilogue::OutputTileIterator::Params params_C(source_tensor.layout()); + + // + // Launch kernel + // + + dim3 grid(1, 1); + dim3 block(Epilogue::WarpCount::kCount * 32, 1); + + test::kernel::epilogue_planar_complex_threadblock<<< grid, block >>>( + params_D, + output_tensor.device_data(), + output_tensor.imaginary_stride(), + params_C, + source_tensor.device_data(), + source_tensor.imaginary_stride(), + output_params, + problem_size, + accumulator_tensor.device_view_real(), + accumulator_tensor.imaginary_stride() + ); + + cudaError_t result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; + return false; + } + + // + // Verify results + // + output_tensor.sync_host(); + + int errors 
= 0; + int const kMaxErrors = 5; + + for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { + for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { + + cutlass::MatrixCoord coord{r, c}; + ComplexElementOutput got = output_tensor.at(coord); + + ComplexElementOutput expected = default_output; + + if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { + + ComplexElementOutput src = source_tensor.at(coord); + + ComplexElementCompute tmp = + output_params.alpha * ComplexElementCompute(accumulator_tensor.at(coord)) + + output_params.beta * ComplexElementCompute(src.real(), src.imag()); + + expected = ComplexElementOutput(ElementOutput(tmp.real()), ElementOutput(tmp.imag())); + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - output element (" << coord << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) << std::endl; + + ++errors; + } + } + } + + // + // Report results on error + // + + if (errors) { + + + std::cout << "Incorrect result for problem(" + << problem_size.row() << ", " + << problem_size.column() << ") for alpha: " << output_params.alpha << ", beta: " << output_params.beta << std::endl; + + std::stringstream ss; + ss + << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" + << Epilogue::WarpTileIterator::WarpShape::kM << "x" + << Epilogue::WarpTileIterator::WarpShape::kN + << "_slice_" << Epilogue::WarpCount::kK << ".csv"; + + std::ofstream output_file(ss.str()); + output_file << output_tensor.host_view(); + + std::cout << "Wrote workspace to '" << ss.str() << "'" << std::endl; + } + + return !errors; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/default_gemm_configuration.hpp 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/default_gemm_configuration.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0054a1b6757a232e9177407fdd2041b6a91cffb9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/default_gemm_configuration.hpp @@ -0,0 +1,1384 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/layout/layout.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +namespace cutlass { +namespace gemm { +namespace device { +using namespace cute; + +// This type is only intended to demonstrate porting 2.x kernels to 3.0 +template< + class OperatorClass, class ArchTag, + class ElementA, class LayoutA, + class ElementB, class LayoutB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types { + static_assert(sizeof(ElementA) == 0, "No valid DefaultGemmConfigurationToCutlass3Types configuration exists."); +}; + +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct DefaultGemm_TensorOpSm80_OperandA; + +template +struct DefaultGemm_TensorOpSm80_OperandB; + +// +// F16: 128-by-128-by-64 +// + +/// Operand A - Row-major (K-Major) +template <> +struct 
DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,3,3>{}, + Layout, + Stride<_64, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); +}; + +/// Operand A - Column-major (M-major) +template +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,3,3>{}, + Layout, + Stride< _1,_64>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride< _1,_16>>{}, + Layout>{})); +}; + +// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands + +// Operand B - Column-Major (K-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +// Operand B - Row-Major (N-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +// +// F16: 128-by-128-by-32 (small k-block) +// + +/// Operand A - Row-major (K-Major) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<2,3,3>{}, + Layout, + Stride<_32, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>{})); +}; + +} + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere MMA F32F16 +template +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + half_t, LayoutA, + half_t, LayoutB, + float, LayoutC, + float> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>, // 2x2x1 thread group + Tile<_32,_32,_16>>; // 32x32x16 
MMA for LDSM, 1x2x1 value group + + // A + static constexpr int kAlignmentA = 8; + using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA< + half_t, LayoutA, kAlignmentA, 32>; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B + static constexpr int kAlignmentB = 8; + using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB< + half_t, LayoutB, kAlignmentB, 32>; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + half_t, TagToStrideA_t, + half_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + float, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// +// TF32: 128-by-128-by-kblock (kBlock = 16, 32) +// + +/// Operand A - Row-major (K-major) (kBlock = 32) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,2,3>{}, + Layout, + Stride<_32, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); +}; + +/// Operand A - Row-major (K-major) (kBlock = 16) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + 
composition(Swizzle<2,2,3>{}, + Layout, + Stride<_16, _1>>{})); + using SmemCopyAtom = Copy_Atom; + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>{})); +}; + +/// Operand A - Column-major (M-major) +template +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,2,3>{}, + Layout, + Stride< _1,_32>>{})); + using SmemCopyAtom = Copy_Atom, tfloat32_t>; + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride< _1,_16>>{}, + Layout>{})); +}; + +// Because the TF32 TiledMMA is A-B symmetric, we can reuse the DefaultOperands + +// Operand B - Column-Major (K-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +// Operand B - Row-Major (N-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +} + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere MMA F32TF32 +template +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + tfloat32_t, LayoutA, + tfloat32_t, LayoutB, + float, LayoutC, + float> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout, Stride<_2, _1, _1>>, // 2x2x1 thread group + Tile<_32,_32,_8>>; // 32x32x8 MMA for LDSM, 1x2x1 value group + + // A + static constexpr int kAlignmentA = 4; + using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA< + tfloat32_t, LayoutA, kAlignmentA, 32>; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B + static constexpr int kAlignmentB = 4; + using DefaultOperandB = 
detail::DefaultGemm_TensorOpSm80_OperandB< + tfloat32_t, LayoutB, kAlignmentB, 32>; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + tfloat32_t, TagToStrideA_t, + tfloat32_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + float, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// +template +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + int8_t, cutlass::layout::RowMajor, + int8_t, cutlass::layout::ColumnMajor, + int32_t, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _64>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>, // 2x2x1 thread group + Tile<_32,_32,_32>>; // 16x16x32 MMA for LDSM, 1x2x1 value group + + // A (M,K) K-major + using SmemLayoutAtomA = decltype( + composition( + Swizzle<2,4,3>{}, + Layout, + Stride<_64, _1>>{})); + static constexpr int kAlignmentA = 16; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, int8_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>>{})); + // LDS.32- or LDSM-based copy atom + // using SmemCopyAtomA = Copy_Atom; + using SmemCopyAtomA = Copy_Atom; // LDSM works + + // B (N,K) K-major + using SmemLayoutAtomB = decltype( + composition( + Swizzle<2,4,3>{}, + Layout, + Stride<_64, _1>>{})); + static constexpr int kAlignmentB = 16; + using GmemTiledCopyB = decltype( 
+ make_tiled_copy(Copy_Atom, int8_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>>{})); + + // LDS.32- or LDSM-based copy atom + // using SmemCopyAtomB = Copy_Atom; + using SmemCopyAtomB = Copy_Atom; // LDSM works + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + int8_t, TagToStrideA_t, + int8_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + int32_t, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////// SIMT TWO STAGE /////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct DefaultGemm_Simt_OperandA; + +/////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultGemm_Simt_OperandA +{ + using SmemLayoutAtom = Layout, + Stride< _1,_128>>; + + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); +}; + +template +struct DefaultGemm_Simt_OperandA +{ + using SmemLayoutAtom = Layout, + Stride< _1,Int<128 + 4>>>; // Padded + + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + Layout, + Stride< _8, _1>>{}, + Layout>{})); + +}; + +template +struct DefaultGemm_Simt_OperandB; + +template +struct DefaultGemm_Simt_OperandB + : DefaultGemm_Simt_OperandA {}; + +template +struct DefaultGemm_Simt_OperandB + : DefaultGemm_Simt_OperandA {}; + +} // end namespace detail + +// SIMT Two Stage +template < + class ArchTag, + class ElementA, class LayoutA, + class 
ElementB, class LayoutB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _8>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>>; + + // A + static constexpr int kAlignmentA = 1; + using DefaultOperandA = detail::DefaultGemm_Simt_OperandA; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B + static constexpr int kAlignmentB = 1; + using DefaultOperandB = detail::DefaultGemm_Simt_OperandB; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + + +// +// DP4A - int8 Proof-of-concept +// + +// SIMT Two Stage TN - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::RowMajor, + int8_t, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + 
using DispatchPolicy = MainloopSm70TwoStage; + // NOTE: permuting MMA M mode lets us generate 128b smem loads (LDS.128) but has worst case bank conflicts + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; // Tile of atoms (threads) + + // A (M,K) K-major + using ElementA = int8_t; + // 40% from regular M and N major layout + // using SmemLayoutAtomA = Layout, + // Stride< _1,_128>>; + // 80% from interleaved layouts + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 4; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // B (N,K) K-major + using ElementB = int8_t; + // 40% from regular M and N major layout + // using SmemLayoutAtomB = Layout, + // Stride< _1,_128>>; + // 80% from interleaved layouts + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 4; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Two Stage NN - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::ColumnMajor, + int8_t, 
cutlass::layout::ColumnMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + + using DispatchPolicy = MainloopSm70TwoStage; + + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) M-major + using ElementA = int8_t; + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // B (N,K) K-major + using ElementB = int8_t; + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 4; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Two Stage NT - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::ColumnMajor, + int8_t, cutlass::layout::RowMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) M-major + using 
ElementA = int8_t; + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // B (N,K) N-major + using ElementB = int8_t; + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Two Stage TT - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::RowMajor, + int8_t, cutlass::layout::RowMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) K-major + using ElementA = int8_t; + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 4; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + 
// B (N,K) N-major + using ElementB = int8_t; + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// SIMT MULTI STAGE ////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage NT +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::ColumnMajor, + ElementB, cutlass::layout::RowMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>, // 16x16x1 thread group + Tile,Stride<_2,_1>>, // 32x32x1 MMA with perm for load vectorization + Layout,Stride<_2,_1>>,Underscore>>; + + // A (M,K) M-major + using SmemLayoutAtomA = Layout>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * kAlignmentA>; + using GmemTiledCopyA = decltype( + 
make_tiled_copy(Copy_Atom, ElementA>{}, + Layout>{}, + Layout>{})); + + // B (N,K) N-major + using SmemLayoutAtomB = Layout>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * kAlignmentB>; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage TN +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::RowMajor, + ElementB, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>>; + + // A (M,K) K-major + using SmemLayoutAtomA = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride<_16, _1>>{})); + + // B (N,K) K-major + using SmemLayoutAtomB = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB + using SmemCopyAtomB = Copy_Atom; + static constexpr int 
kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride<_16, _1>>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage NN +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::ColumnMajor, + ElementB, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>, // 16x16x1 thread group + Tile,Stride<_2,_1>>,Underscore,Underscore>>; // 32x16x1 MMA with perm for load vectorization + + // A (M,K) M-major + using SmemLayoutAtomA = Layout>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * kAlignmentA>; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout>{}, + Layout>{})); + + // B (N,K) K-major + using SmemLayoutAtomB = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride<_16, 
_1>>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage TT +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::RowMajor, + ElementB, cutlass::layout::RowMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>, // 16x16x1 thread group + Tile,Stride<_2,_1>>,Underscore>>; // 16x32x1 MMA with perm for load vectorization + + // A (M,K) K-major + using SmemLayoutAtomA = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride<_16, _1>>{})); + + // B (N,K) N-major + using SmemLayoutAtomB = Layout>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * kAlignmentB>; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + 
ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA TN (K-Major A and K-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; + + // A (M,K) K-Major + using SmemLayoutAtomA = decltype( + composition(Swizzle<2,0,4>{}, + Layout, + Stride<_1, _4>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // B (N,K) K-Major + using SmemLayoutAtomB = decltype( + composition(Swizzle<2,0,4>{}, + Layout, + Stride<_1, _4>>{})); // N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + 
double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + double, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; + +/* + using EpilogueOutputOp = epilogue::collective::Epilogue< + epilogue::thread::LinearCombination, + Layout, + Stride< _1,_64>>, // SMEM layout + Copy_Atom,double>, // R2S with tiled_mma layout + decltype(make_tiled_copy(Copy_Atom,double>{},// S2R + Layout, + Stride< _1,_16>>{}, // Thread layout + Layout>{})), // Value layout + Copy_Atom,double> // R2G with S2R_dst layout + >; +*/ +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA NN (M-Major A and K-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; + + // A (M,K) M-Major + using SmemLayoutAtomA = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // B (N,K) K-Major + using SmemLayoutAtomB = decltype( + composition(Swizzle<2,0,4>{}, + Layout, + Stride<_1, _4>>{}));// N, K + 
using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + double, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA NT (M-Major A and N-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; + + // A (M,K) M-Major + using SmemLayoutAtomA = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // B (N,K) N-Major + using SmemLayoutAtomB = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // N, K + 
using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + double, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA TT (K-Major A and N-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::RowMajor, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; + + // A (M,K) K-Major + using SmemLayoutAtomA = decltype( + composition(Swizzle<2,0,4>{}, + Layout, + Stride<_1, _4>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // B (N,K) N-Major + using SmemLayoutAtomB = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // N, K + using 
SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + double, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Hopper fp64 MMA TN +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm90, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) K-major + using SmemLayoutAtomA = decltype( + make_ordered_layout(Shape<_128,_16>{}, + Step < _2, _1>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // B (N,K) K-major + using SmemLayoutAtomB = decltype( + make_ordered_layout(Shape<_64,_16>{}, + Step < _2, _1>{})); // N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = 
collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + double, double, + double, cutlass::layout::ColumnMajor, 1, + double, cutlass::layout::ColumnMajor, 1, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..89755dd7d3162b114a537e58c6aa33cac80078f9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -0,0 +1,3993 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include // std::lcm + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/complex.h" +#include "cutlass/transform/device/transform_universal_adapter.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/detail/collective.hpp" + +#include "testbed_utils.h" + +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/gemm.h" + +#include "cute/int_tuple.hpp" +#include "cute/layout.hpp" +#include "cute/numeric/int.hpp" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class ScalarLoc { + ON_HOST = 0, + ON_DEVICE = 1 +}; + +enum class VectorScale { + DISABLED = 0, + ENABLED = 1 +}; + +enum class CheckEquality { + EXACT = 0, + RELATIVE = 1 +}; + +namespace detail { + +inline constexpr auto decomp_mode_to_string = + [] (cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode mode) -> std::string { + using Mode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + if (mode == Mode::Heuristic) { + return "Heuristic"; + } + else if (mode == Mode::DataParallel) { + return "DataParallel"; + } + else if 
(mode == Mode::SplitK) { + return "SplitK"; + } + else if (mode == Mode::StreamK) { + return "StreamK"; + } + else { + return "Unknown"; + } + }; + +inline constexpr auto raster_order_to_string = + [] (cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions mode) -> std::string { + using Mode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions; + if (mode == Mode::Heuristic) { + return "Heuristic"; + } + else if (mode == Mode::AlongM) { + return "AlongM"; + } + else if (mode == Mode::AlongN) { + return "AlongN"; + } + else { + return "Unknown"; + } + }; + +// Helper classes that take default data type when +// the Gemm::EpilogueOutputOp does not have ElementCompute +// and ElementScalar. +// (e.g. when Sm90TreeVisitor is used as FusionCallbacks) +template +struct ElementComputeType { + using Type = Default; +}; + +template +struct ElementComputeType>> { + using Type = typename Gemm::EpilogueOutputOp::ElementCompute; +}; + +template +struct ElementScalarType { + using Type = Default; +}; + +template +struct ElementScalarType>> { + using Type = typename Gemm::EpilogueOutputOp::ElementScalar; +}; + + +template +struct IsF8F6F4Kernel { + static constexpr bool value = false; +}; + +template +struct IsF8F6F4Kernel> { + static constexpr bool value = true; +}; + + +template +struct IsSfdEpi : cute::false_type {}; + +template +struct IsSfdEpi> : cute::true_type {}; + +// The maximum swizzle size to use +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. 
+class MaxSwizzleSize { +public: + MaxSwizzleSize() = default; + + template && + !cute::is_same_v)) > + explicit MaxSwizzleSize(IntegralNotBool max_swizzle_size) : max_swizzle_size_(max_swizzle_size) {} + explicit operator int() const { return max_swizzle_size_; } +private: + int max_swizzle_size_ = 1; +}; + +template +auto make_iterator(T* ptr) { + return cute::recast_ptr(ptr); +} + +template +struct IsDefaultEpilogue { + static constexpr bool value = false; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; + +template +struct IsLegacyEpiloguePolicy { + static constexpr bool value = false; +}; + +template +struct IsLegacyEpiloguePolicy> { + using EpiloguePolicy = typename Epilogue::DispatchPolicy; + static constexpr bool value = cute::is_same_v< + EpiloguePolicy, + cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise< + EpiloguePolicy::StagesC, EpiloguePolicy::StagesD, EpiloguePolicy::FragmentSize>>; +}; + +// The number of splits to test. +// +// This class makes it harder to confuse the order of arguments +// of the various run(...) functions in this file. The constructor +// is explicit, so one can't just type 42 (or false, which the +// compiler unhelpfully turns into 0); one has to type Splits(42). +// Splits() picks the default number of splits, 1. +// +// The conversion-to-int operator (operator int()) MUST be explicit! +// Conversion to int MUST require static_cast. +// Otherwise, that defeats a key purpose of this class, +// which is to catch common errors of confusing the order +// of function arguments. +class Splits { +public: + Splits() = default; + + template && + !cute::is_same_v)) > + explicit Splits(IntegralNotBool splits) : splits_(splits) {} + explicit operator int() const { return splits_; } +private: + int splits_ = 1; +}; + +// The number of iterations to test. 
+// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. +// Iterations() picks the default number of iterations, 20. +class Iterations { +public: + Iterations() = default; + + template && + !cute::is_same_v)) > + explicit Iterations(IntegralNotBool iterations) : iterations_(iterations) {} + explicit operator int() const { return iterations_; } +private: + int iterations_ = 20; +}; + +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + + else if (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + + else if (bits_input <= 8) { + + if constexpr ( + cute::is_same_v){ + scope_max = 4; + scope_min = 1; + } + else { + + scope_max = 1; + scope_min = -1; + + } + + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + + else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } + + else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; +} + +// Looks at Cute Stride to check Row / Column Major +template +static constexpr bool is_row_or_col_major(){ + int stride_0 = int(cute::size<0>(Stride{})); + int stride_1 = 
int(cute::size<1>(Stride{})); + int depth = cute::depth(Stride{}); + return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1); +} + + +// +// Default MMA input Operands : A , B +// +template< + class ScheduleType_, + class Gemm, + class ElementA_ = typename Gemm::GemmKernel::ElementA, + class ElementB_ = typename Gemm::GemmKernel::ElementB, + class Enable = void> +struct HostCollectiveMainloop { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + cutlass::ComplexTransform TransformA = Gemm::kTransformA; + cutlass::ComplexTransform TransformB = Gemm::kTransformB; + + StrideA stride_a; + StrideB stride_b; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + 
CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_), + check_relative_equality(check_relative_equality_) { } + + template + bool initialize(ProblemShapeType problem_size) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop (generic)::initialize(problem_shape)"); +#endif + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M * L, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N * L); + + try { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_A.resize"); +#endif + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_B.resize"); +#endif + tensor_B.resize(b_coord, 
cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor A or B resize threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor A or B resize threw an unknown exception"); + throw; + } + + try { + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); + } + catch (cutlass::cuda_exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: checked initialize_tensor threw cutlass::cuda_exception: " << e); + throw; + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: checked initialize_tensor threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: checked_initialize_tensor threw an unknown exception"); + throw; + } + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_A.host_view().at({0, 0}) = ElementA(1); + tensor_B.host_view().at({0, 0}) = ElementB(1); + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: Check last error before sync_device()"); + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: cudaGetLastError() is " << error_str); + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_A.host_data()=" << tensor_A.host_data() << ", tensor_A.device_data()=" << tensor_A.device_data()); + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_B.host_data()=" << tensor_B.host_data() << ", tensor_B.device_data()=" << tensor_B.device_data()); + } +#endif + try { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_A.sync_device"); +#endif + tensor_A.sync_device(); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_B.sync_device"); +#endif + tensor_B.sync_device(); + } + catch (cutlass::cuda_exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: sync_device() threw cutlass::cuda_exception: " << e); + throw; + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: sync_device() threw an exception: " << e.what()); + throw; + } + catch (...) 
{ + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: sync_device() threw an unknown exception"); + throw; + } + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: Reached end"); +#endif + return true; + } + + Arguments to_args() { + + + // Runtime datatype selection + if constexpr (not cute::is_same_v) { + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + return { + reinterpret_cast(tensor_A.device_data()), stride_a, + reinterpret_cast(tensor_B.device_data()), stride_b + }; + } + else { + + Arguments arguments = + { + tensor_A.device_data(), stride_a, tensor_B.device_data(), stride_b + }; + return arguments; + } + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + auto A = make_tensor(make_iterator(tensor_A.host_data()), + make_layout(make_shape(M, K, L), stride_a)); + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(N, K, L), stride_b)); + + + auto dummy_SFA = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, K, L), stride_a)); + auto dummy_SFB = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(N, K, L), stride_b)); + + cutlass::reference::host::GettMainloopParams mainloop_params{}; + + mainloop_params.A = A; + mainloop_params.B = B; + mainloop_params.transform_A = TransformA; + mainloop_params.transform_B = TransformB; + + return mainloop_params; + } + + void print_tensors(std::ofstream& file) { + file << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view(); + } + + 
template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + cute::Shape problem_shape_MNKL) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + + bool passed = true; + return passed; + } +}; + +// +// Sparse MMA host implementation +// +template< + class Gemm, + class ElementA_, + class ElementB_> +struct HostCollectiveMainloopSparse +{ + + // Kernel data types + using ElementA = ElementA_; + // CuTe layout A for the kernel's sparse tensorA. + using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + + using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE; + // CuTe layout E for the kernel's metadata tensor. 
+ using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + + // The following typenames are for the reference host tensors. They are non-sparse tensors. + using LayoutTagA = decltype(SparseConfig::deduce_layoutA_tag(LayoutA{})); + using StrideA = cutlass::gemm::TagToStrideA_t; + // We don't care about the actual strideE for the host tensor, but just need one to allocate memory. + using StrideE = StrideA; + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + using LayoutTagE = cutlass::detail::StrideToLayoutTagA_t; + + using ArchTag = typename Gemm::ArchTag; + + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + cute::Shape, + ElementA, + LayoutTagA, + SparseConfig>; + + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + cute::Shape, + ElementA, + LayoutTagA, + SparseConfig, + ArchTag>; + + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + StrideA stride_a; + StrideA stride_a_compressed; + StrideB stride_b; + StrideE stride_e; + + LayoutA layout_a; + LayoutE layout_e; + + typename LayoutTagA::Stride stride_factor_A; + typename 
LayoutTagB::Stride stride_factor_B; + typename LayoutTagE::Stride stride_factor_E; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_Comp; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_E; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + static constexpr int MaxSmCount = 16; + + HostCollectiveMainloopSparse( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride(), + typename LayoutTagE::Stride stride_factor_E_ = typename LayoutTagE::Stride() + ): + check_relative_equality(check_relative_equality_), + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + stride_factor_E(stride_factor_E_), + init_A(init_A_), init_B(init_B_), seed(seed_) { } + + template + bool initialize(ProblemShapeType problem_size) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloopSparse::initialize"); +#endif + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + + CompressorUtility compressor_utility(problem_shape_MNKL, stride_a); + + // TensorE + // In unit of ElementE (uint8_t), after alignment requirement + // M-dim: TensorEAtom_M alignment + // K-dim: TensorEAtom_K alignment + 
int KAlignedE = compressor_utility.get_metadata_k_physical(); + int MAlignedE = compressor_utility.get_metadata_m_physical(); + + // TensorA Compressed + // In unit of ElementARaw, after alignment requirement + // M-dim: TMA alignment + // K-dim: TMA alignment + int KAlignedAC = compressor_utility.get_tensorA_k_physical(); + int MAlignedAC = compressor_utility.get_tensorA_m_physical(); + + stride_a_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KAlignedAC, L)); + stride_e = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, L)); + + auto a_coord = cutlass::make_Coord(M * L, K); + auto b_coord = cutlass::make_Coord(K, N * L); + auto e_coord = cutlass::make_Coord(MAlignedE * L, KAlignedE); + auto a_comp_coord = cutlass::make_Coord(MAlignedAC * L, KAlignedAC); + + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + tensor_A_Comp.resize(a_comp_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + tensor_E.resize(e_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_A.host_view().at({0, 0}) = ElementA(1); + tensor_B.host_view().at({0, 0}) = ElementB(1); + + compressor_utility.structure_sparse_zero_mask_fill(tensor_A.host_data(), static_cast(seed + 2023)); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_E.sync_device(); + tensor_A_Comp.sync_device(); + + cutlass::Status status {cutlass::Status::kSuccess }; + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {M, N, K, L}, + {tensor_A.device_data(), + stride_a, + tensor_A_Comp.device_data(), + tensor_E.device_data()}, + {hw_info} + }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.run(); + + auto result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + layout_a = SparseConfig::fill_layoutA(problem_shape_MNKL); + layout_e = SparseConfig::fill_layoutE(problem_shape_MNKL); + + tensor_E.sync_host(); + tensor_A_Comp.sync_host(); + + return true; + } + + Arguments to_args() { + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + return { + reinterpret_cast(tensor_A_Comp.device_data()), layout_a, + reinterpret_cast(tensor_B.device_data()), stride_b, + tensor_E.device_data(), layout_e + }; + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace 
+ // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + auto A = make_tensor(make_iterator(tensor_A.host_data()), + make_layout(make_shape(M, K, L), stride_a)); + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(N, K, L), stride_b)); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + return mainloop_params; + } + + void print_tensors(std::ofstream& file) { + file << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view(); + } + + bool compare_reference( + cute::Shape problem_shape_MNKL) { + auto [M, N, K, L] = problem_shape_MNKL; + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + return true; + } +}; + +template< + class ScheduleType_, + class Gemm, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + typename Gemm::CollectiveMainloop::DispatchPolicy>>> + : HostCollectiveMainloopSparse +{ + using HostCollectiveMainloopSparse::HostCollectiveMainloopSparse; +}; + +// +// Sparse MMA input Operands : A_compressed, B, metadata +// +// Structured Sparse Gemm Input Operands + +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + typename ElementA_, + typename ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> + : HostCollectiveMainloopSparse +{ + using HostCollectiveMainloopSparse::HostCollectiveMainloopSparse; +}; + +// +// Sparse Gemm Input Operands : A , B, E +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = 
HostCollectiveMainloop, + Gemm, ElementA_, ElementB_ >; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), + typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, + stride_factor_B_, + stride_factor_E_) {} +}; + +// +// Sparse Gemm Input Operands : A , B, E +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_ >; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), + typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, + stride_factor_B_, + stride_factor_E_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class 
ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + + using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; + using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + StrideA stride_a; + StrideB stride_b; + + LayoutSFA layout_sfa; + LayoutSFB layout_sfb; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_SFA; + cutlass::HostTensor tensor_SFB; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + 
+ // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + check_relative_equality(check_relative_equality_), + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_) { } + + template + bool initialize(ProblemShapeType problem_size) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop (KernelTmaWarpSpecializedBlockScaledSm100)::initialize"); +#endif + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M * L, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N * L); + + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, 
stride_factor_A)); + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = ElementA(1); + tensor_B.host_view().at({0, 0}) = ElementB(1); + + tensor_A.sync_device(); + tensor_B.sync_device(); + + using namespace cute; + auto k_blks = cutlass::ceil_div(K, size<1>(shape(SfAtom{}))); + auto m_blks = cutlass::ceil_div(M, Blk_MN{}); + auto n_blks = cutlass::ceil_div(N, Blk_MN{}); + layout_sfa = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); + layout_sfb = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * L, k_blks * Blk_SF{}); + auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * L, k_blks * Blk_SF{}); + + tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A)); + tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B)); + + EXPECT_TRUE(initialize_tensor(tensor_SFA.host_view(), init_A, seed + 2024)); + EXPECT_TRUE(initialize_tensor(tensor_SFB.host_view(), init_B, seed + 2025)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_SFA.host_view().at({0, 0}) = ElementSF(1); + tensor_SFB.host_view().at({0, 0}) = ElementSF(1); + + tensor_SFA.sync_device(); + tensor_SFB.sync_device(); + + return true; + } + + Arguments to_args() { + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + return { + reinterpret_cast(tensor_A.device_data()), stride_a, + reinterpret_cast(tensor_B.device_data()), stride_b, + tensor_SFA.device_data(), layout_sfa, + tensor_SFB.device_data(), layout_sfb + }; + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + auto A = make_tensor(make_iterator(tensor_A.host_data()), + make_layout(make_shape(M, K, L), stride_a)); + auto SfA = make_tensor(tensor_SFA.host_data(), layout_sfa); + + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(N, K, L), stride_b)); + auto SfB = make_tensor(tensor_SFB.host_data(), layout_sfb); + + cutlass::reference::host::GettMainloopParams + mainloop_params{A, SfA, B, SfB}; + return mainloop_params; + } + + void print_tensors(std::ofstream& file) { + file << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nSFA =\n" << tensor_SFA.host_view() + << "\nSFB =\n" << tensor_SFB.host_view(); + } + + bool compare_reference( + cute::Shape problem_shape_MNKL) { + auto [M, N, K, L] = problem_shape_MNKL; + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFA.host_view()), 0); + 
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFB.host_view()), 0); + return true; + } +}; + + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, 
stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Structured Sparse Gemm Input Operands : A_compressed, B, metadata, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + typename ElementA_, + typename ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + // Kernel data types + using ElementA = ElementA_; + // CuTe layout A for the kernel's sparse tensorA. + using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + + using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE; + // CuTe layout E for the kernel's metadata tensor. 
+ using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + + // The following typenames are for the reference host tensors. They are non-sparse tensors. + using LayoutTagA = decltype(SparseConfig::deduce_layoutA_tag(LayoutA{})); + using StrideA = cutlass::gemm::TagToStrideA_t; + // We don't care about the actual strideE for the host tensor, but just need one to allocate memory. + using StrideE = StrideA; + + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + using LayoutTagE = cutlass::detail::StrideToLayoutTagA_t; + + using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; + using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + cute::Shape, + ElementA, + LayoutTagA, + SparseConfig>; + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + cute::Shape, + ElementA, + LayoutTagA, + SparseConfig, + cutlass::arch::Sm100>; + + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + // 
Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + StrideA stride_a; + StrideA stride_a_compressed; + StrideB stride_b; + StrideE stride_e; + + LayoutA layout_a; + LayoutE layout_e; + LayoutSFA layout_sfa; + LayoutSFB layout_sfb; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + typename LayoutTagE::Stride stride_factor_E; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_Comp; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_E; + cutlass::HostTensor tensor_SFA; + cutlass::HostTensor tensor_SFB; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride(), + typename LayoutTagE::Stride stride_factor_E_ = typename LayoutTagE::Stride() + ): + check_relative_equality(check_relative_equality_), + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + stride_factor_E(stride_factor_E_), + init_A(init_A_), init_B(init_B_), seed(seed_) { } + + template + bool initialize(ProblemShapeType problem_size) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop (KernelSparseTmaWarpSpecializedBlockScaledSm100)::initialize"); +#endif + // + // Allocate 
the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + + CompressorUtility compressor_utility(problem_shape_MNKL, stride_a); + + // TensorE + // In unit of ElementE (uint8_t), after alignment requirement + // M-dim: TensorEAtom_M alignment + // K-dim: TensorEAtom_K alignment + int KAlignedE = compressor_utility.get_metadata_k_physical(); + int MAlignedE = compressor_utility.get_metadata_m_physical(); + + // TensorA Compressed + // In unit of ElementARaw, after alignment requirement + // M-dim: TMA alignment + // K-dim: TMA alignment + int KAlignedAC = compressor_utility.get_tensorA_k_physical(); + int MAlignedAC = compressor_utility.get_tensorA_m_physical(); + + stride_a_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KAlignedAC, L)); + stride_e = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, L)); + + auto a_coord = cutlass::make_Coord(M * L, K); + auto b_coord = cutlass::make_Coord(K, N * L); + auto e_coord = cutlass::make_Coord(MAlignedE * L, KAlignedE); + auto a_comp_coord = cutlass::make_Coord(MAlignedAC * L, KAlignedAC); + + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + tensor_A_Comp.resize(a_comp_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + tensor_E.resize(e_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); + + 
EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = ElementA(1); + tensor_B.host_view().at({0, 0}) = ElementB(1); + + compressor_utility.structure_sparse_zero_mask_fill(tensor_A.host_data(), static_cast(seed + 2023)); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_E.sync_device(); + tensor_A_Comp.sync_device(); + + cutlass::Status status {cutlass::Status::kSuccess }; + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {M, N, K, L}, + {tensor_A.device_data(), + stride_a, + tensor_A_Comp.device_data(), + tensor_E.device_data()}, + {hw_info} + }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.run(); + + auto result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + layout_a = SparseConfig::fill_layoutA(problem_shape_MNKL); + layout_e = SparseConfig::fill_layoutE(problem_shape_MNKL); + + tensor_E.sync_host(); + tensor_A_Comp.sync_host(); + + using namespace cute; + auto k_blks = cutlass::ceil_div(K, size<1>(shape(SfAtom{}))); + auto m_blks = cutlass::ceil_div(M, Blk_MN{}); + auto n_blks = cutlass::ceil_div(N, Blk_MN{}); + layout_sfa = 
Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); + layout_sfb = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * L, k_blks * Blk_SF{}); + auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * L, k_blks * Blk_SF{}); + + tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A)); + tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B)); + + EXPECT_TRUE(initialize_tensor(tensor_SFA.host_view(), init_A, seed + 2024)); + EXPECT_TRUE(initialize_tensor(tensor_SFB.host_view(), init_B, seed + 2025)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_SFA.host_view().at({0, 0}) = ElementSF(1); + tensor_SFB.host_view().at({0, 0}) = ElementSF(1); + + tensor_SFA.sync_device(); + tensor_SFB.sync_device(); + + return true; + } + + Arguments to_args() { + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + return { + reinterpret_cast(tensor_A_Comp.device_data()), layout_a, + reinterpret_cast(tensor_B.device_data()), stride_b, + tensor_E.device_data(), layout_e, + tensor_SFA.device_data(), layout_sfa, + tensor_SFB.device_data(), layout_sfb + }; + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + auto A = 
make_tensor(make_iterator(tensor_A.host_data()), + make_layout(make_shape(M, K, L), stride_a)); + auto SfA = make_tensor(tensor_SFA.host_data(), layout_sfa); + + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(N, K, L), stride_b)); + auto SfB = make_tensor(tensor_SFB.host_data(), layout_sfb); + + // return {A, SfA, B, SfB}; + cutlass::reference::host::GettMainloopParams + mainloop_params{A, SfA, B, SfB}; + return mainloop_params; + } + + void print_tensors(std::ofstream& file) { + file << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nSFA =\n" << tensor_SFA.host_view() + << "\nSFB =\n" << tensor_SFB.host_view(); + } + + bool compare_reference( + cute::Shape problem_shape_MNKL) { + auto [M, N, K, L] = problem_shape_MNKL; + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFA.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFB.host_view()), 0); + return true; + } +}; + +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), + typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() + ) : 
Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, + stride_factor_B_, + stride_factor_E_) {} +}; + +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), + typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, + stride_factor_B_, + stride_factor_E_) {} +}; + +template +struct HostCollectiveDefaultEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; + + using FusionOp = typename Gemm::EpilogueOutputOp; + + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + 
static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + StrideC stride_c; + StrideD stride_d; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + + cutlass::HostTensor tensor_C; + // Inputs + ElementScalar alpha; + ElementScalar beta; + + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If per-row scale is enabled and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector + VectorScale vector_scale_mode = VectorScale::DISABLED; + + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + HostCollectiveDefaultEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, 
+ cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } + + bool initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize(problem_size, alpha, beta)"); +#endif + // Initialize Epilogue tensors + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto c_coord = cutlass::make_Coord(M * L, N); + try { + tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); + tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D)); + reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: resizing tensors threw an exception: " << e.what()); + throw; + } + catch (...) 
{ + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: resizing tensors threw an unknown exception"); + throw; + } + { + const bool init_succeeded = initialize_tensor(tensor_C.host_view(), init_C, seed + 2020); + if (not init_succeeded) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: initialize_tensor returned false"); + } + EXPECT_TRUE(init_succeeded); + } + tensor_C.host_view().at({0, 0}) = ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + try { + tensor_C.sync_device(); + tensor_D.sync_device(); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: sync_device() threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: sync_device() threw an unknown exception"); + throw; + } + + alpha = alpha_; + beta = beta_; + + return true; + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. 
+ Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta) { + auto [M, N, K, L] = problem_shape_MNKL; + + tensor_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + } + + if (reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + } + + bool passed = equality_check(reference_D.host_view(), tensor_D.host_view()); + if(!passed) { + std::cout<<"D is incorrect"<(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensor_C.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + auto D = cute::make_tensor(detail::make_iterator(reference_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D)> + epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha; + epilogue_params.beta = beta; + + return epilogue_params; + } +}; + +template +struct HostCollectiveEpilogue { + // fusion types are potentially void if the fusion is 
not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + static_assert(IsDefaultEpilogue::value == false, "Default Epilogue is not supported"); + + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; + + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + + // + // FusionOperation derived types/queries + // + static constexpr bool IsLegacy = detail::IsLegacyEpiloguePolicy::value; + + // FFMA2 SGEMM uses ThreadEpilogueOp for bias and relu support instead of FusionOp, so we compose LinCombPerRowBiasEltAct FusionOp by hand to test the functionality. 
+ static constexpr bool IsFfma2Kernel = cute::is_same_v; + using FusionOp = cute::conditional_t, + typename Gemm::EpilogueOutputOp>; + static_assert(cute::is_base_of_v); + + + // Scale factor Generation related + using SfStrategy = cutlass::reference::host::SfStrategy; + static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; + static constexpr SfStrategy SfGenStrategy = (!IsBlockScaleSupported) ? SfStrategy::None : SfStrategy::SfDGen; + static constexpr int32_t SFD_VectorSize = IsBlockScaleSupported ? FusionOp::SFVecSize : 1; + static constexpr bool IsKMajorSFD = cute::is_same_v; + using ElementSFD = non_void_t; + using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig; + using Blk_MN = typename Sm1xxBlockScaledOutputConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlockScaledOutputConfig::Blk_SF; + using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom; + cutlass::HostTensor tensor_SFD; + cutlass::HostTensor reference_SFD; + + using ElementCompute = typename FusionOp::ElementCompute; + using ElementScalar = typename FusionOp::ElementScalar; + using ElementBias = non_void_t; + using ElementAux = non_void_t; + using ElementAmax = non_void_t; + using LayoutTagAux = non_void_t; + using ActivationFunctor = non_void_t>; + + static constexpr bool IsRowBiasEnabled = FusionOp::IsPerRowBiasSupported; + static constexpr bool IsColBiasEnabled = FusionOp::IsPerColBiasSupported; + static_assert(not (IsColBiasEnabled && IsRowBiasEnabled)); + + static constexpr bool IsDeBiasEnabled = FusionOp::IsDePerRowBiasSupported; + static constexpr bool IsPerRowScaleEnabled = FusionOp::IsPerRowScaleSupported; + static constexpr bool IsPerColScaleEnabled = FusionOp::IsPerColScaleSupported; + static constexpr bool IsScaleFactorEnabled = FusionOp::IsScaleFactorSupported; + static constexpr bool IsAuxInEnabled = FusionOp::IsAuxInSupported; + static constexpr bool IsAuxOutEnabled = FusionOp::IsAuxOutSupported; + static constexpr bool 
IsAbsMaxEnabledD = FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + StrideC stride_c; + StrideD stride_d; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + + // Inputs + cutlass::HostTensor alpha; + cutlass::HostTensor beta; + cutlass::HostTensor scale_A; + cutlass::HostTensor scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor bias; + cutlass::HostTensor tensor_C; + cutlass::HostTensor norm_constant; + + // Outputs + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + cutlass::HostTensor tensor_Aux; + cutlass::gemm::TagToStrideC_t< LayoutTagAux > stride_Aux; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // References + cutlass::HostTensor reference_dbias; + cutlass::HostTensor reference_Aux; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If vector scale is supported and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector + VectorScale vector_scale_mode = VectorScale::DISABLED; + + // Random distribution with which to initialize the A/B/C/D/Aux scaling factors + cutlass::Distribution::Kind init_scale = cutlass::Distribution::Uniform; + // Random distribution with which to initialize the bias vector + cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t 
kDefaultSeed = 4096; + + HostCollectiveEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_scale(init_scale_), init_bias(init_bias_), + init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } + + bool initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize(problem_size, alpha, beta)"); +#endif + // Initialize Epilogue tensors + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto c_coord = cutlass::make_Coord(M * L, N); + try { + tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); + tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D)); + reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false); + } + catch (std::exception 
const& e) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: resizing tensors threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: resizing tensors threw an unknown exception"); + throw; + } + + try { + const bool initialize_tensor_C_succeeded = + initialize_tensor(tensor_C.host_view(), init_C, seed + 2020); + if (not initialize_tensor_C_succeeded) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: initialize_tensor returned false"); + } + EXPECT_TRUE(initialize_tensor_C_succeeded); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: initialize_tensor threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: initialize_tensor threw an unknown exception"); + throw; + } + + tensor_C.host_view().at({0, 0}) = ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + try { + tensor_C.sync_device(); + tensor_D.sync_device(); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: sync_device() threw an exception: " << e.what()); + throw; + } + catch (...) 
{ + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: sync_device() threw an unknown exception"); + throw; + } + + auto scalar_coord = cutlass::make_Coord(1); + auto col_vector_coord = cutlass::make_Coord(M); + auto row_vector_coord = cutlass::make_Coord(N); + auto batch_vector_coord = cutlass::make_Coord(L); + if constexpr (IsPerRowScaleEnabled or IsPerColScaleEnabled) { + // scalars + if (vector_scale_mode == VectorScale::DISABLED) { + // batched scalars + if (use_device_scalars == ScalarLoc::ON_DEVICE) { + alpha.resize(batch_vector_coord, true); + beta.resize(batch_vector_coord, true); + EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); + if (beta_ != ElementScalar(0)) { + EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); + } + else { + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + } + // non-batched scalars + else { + alpha.resize(scalar_coord, false); + beta.resize(scalar_coord, false); + cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + } + // batched vectors + else { + auto batched_vector_coord = cutlass::make_Coord((IsPerRowScaleEnabled ? M : N) * L); + alpha.resize(batched_vector_coord, true); + beta.resize(batched_vector_coord, true); + EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); + if (beta_ != ElementScalar(0)) { + EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); + } + else { + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + } + } + else { + if (use_device_scalars == ScalarLoc::ON_DEVICE) { + // Set alpha beta for different batches. 
+ alpha.resize(batch_vector_coord, true); + beta.resize(batch_vector_coord, true); + cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); + for (int l = 0; l < L; ++l) { + beta.host_view().at(cutlass::make_Coord(l)) = beta_ + ElementScalar(l); + } + } + else { + alpha.resize(scalar_coord, false); + beta.resize(scalar_coord, false); + cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + } + alpha.sync_device(); + beta.sync_device(); + + if constexpr (IsScaleFactorEnabled) { + scale_A.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_B.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_C.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_D.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_A.host_view(), init_scale, seed + 2023)); + EXPECT_TRUE(initialize_tensor(scale_B.host_view(), init_scale, seed + 2024)); + EXPECT_TRUE(initialize_tensor(scale_C.host_view(), init_scale, seed + 2025)); + EXPECT_TRUE(initialize_tensor(scale_D.host_view(), init_scale, seed + 2026)); + scale_A.sync_device(); + scale_B.sync_device(); + scale_C.sync_device(); + scale_D.sync_device(); + } + + if constexpr (IsRowBiasEnabled or IsColBiasEnabled) { + bias.resize(IsRowBiasEnabled ? 
col_vector_coord : row_vector_coord); + EXPECT_TRUE(initialize_tensor(bias.host_view(), init_bias, seed + 2023)); + bias.sync_device(); + } + + if constexpr (IsDeBiasEnabled) { + bias.resize(col_vector_coord); + reference_dbias.resize(col_vector_coord); + cutlass::reference::host::TensorFill(bias.host_view(), ElementBias(0)); + cutlass::reference::host::TensorFill(reference_dbias.host_view(), ElementBias(0)); + bias.sync_device(); + } + + if constexpr (IsAbsMaxEnabledD) { + abs_max_D.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_D.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); + abs_max_D.sync_device(); + reference_abs_max_D.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_D.host_view(), ElementAmax(0)); + } + + if constexpr (IsAuxInEnabled) { + auto aux_coord = cutlass::make_Coord(M * L, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + tensor_Aux.resize(aux_coord, aux_layout); + EXPECT_TRUE(initialize_tensor(tensor_Aux.host_view(), init_C, seed + 2023)); + tensor_Aux.sync_device(); + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, L)); + } + + if constexpr (IsAuxOutEnabled) { + auto aux_coord = cutlass::make_Coord(M * L, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + tensor_Aux.resize(aux_coord, aux_layout); + reference_Aux.resize(aux_coord, aux_layout, false); + tensor_Aux.sync_device(); + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, L)); + + if constexpr (IsScaleFactorEnabled) { + scale_Aux.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_Aux.host_view(), init_scale, seed + 2027)); + 
scale_Aux.sync_device(); + } + + if constexpr (IsAbsMaxEnabledAux) { + abs_max_Aux.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_Aux.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); + abs_max_Aux.sync_device(); + reference_abs_max_Aux.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_Aux.host_view(), ElementAmax(0)); + } + } + + + if constexpr (IsBlockScaleSupported) { + auto m_blks = cutlass::ceil_div(M, cute::size<0>(cute::shape(OutputSFAtom{}))); + auto n_blks = cutlass::ceil_div(N, cute::size<1>(cute::shape(OutputSFAtom{}))); + auto sfd_coord = [&] () { + if constexpr (IsKMajorSFD) { + return cutlass::make_Coord(m_blks * Blk_MN{} * L, n_blks * Blk_SF{}); + } + else { + return cutlass::make_Coord(m_blks * Blk_SF{} * L, n_blks * Blk_MN{}); + } + }(); + tensor_SFD.resize(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D)); + reference_SFD.resize(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D), false); + tensor_SFD.sync_device(); + norm_constant.resize(scalar_coord, true); + EXPECT_TRUE(initialize_tensor(norm_constant.host_view(), init_scale, seed + 2023)); + norm_constant.sync_device(); + } + + + return true; + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. 
+ Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta) { + tensor_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + } + + if (reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + } + + bool passed = equality_check(reference_D.host_view(), tensor_D.host_view()); + if(!passed) { + #if 0 + auto [M, N, K, L] = problem_shape_MNKL; + auto ref = cute::make_tensor(detail::make_iterator(reference_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + auto comp = cute::make_tensor(detail::make_iterator(tensor_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + for(int i=0; i(ElementD(ref(i, j, l))) != static_cast((ElementD(comp(i, j, l))))) { + printf(" ref: %f comp: %f\n", i, j, l, static_cast(ElementD(ref(i, j, l))), static_cast((ElementD(comp(i, j, l))))); + } + } + } + } + #endif + std::cout<<"D is incorrect"<(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + Arguments arguments = + { + {}, + tensor_C.device_data(), stride_c, tensor_D.device_data(), stride_d + }; + + auto &fusion_args = arguments.thread; + if constexpr (IsLegacy) { + arguments.thread = { + alpha.at(coord_0), + beta.at(coord_0), + alpha.device_data(), + beta.device_data() + }; + arguments.ptr_Bias = bias.device_data(); + arguments.ptr_T = 
tensor_Aux.device_data(); + } + else { + fusion_args.alpha = alpha.at(coord_0); + fusion_args.alpha_ptr = alpha.device_data(); + // Only initializing beta/beta_ptr for non-void source + if constexpr (not cute::is_void_v) { + fusion_args.beta = beta.at(coord_0); + fusion_args.beta_ptr = beta.device_data(); // if vector_scale_mode is true this is nullptr + } + + if constexpr (IsPerRowScaleEnabled) { + int32_t m_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int64_t l_stride = vector_scale_mode == VectorScale::ENABLED ? M : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + fusion_args.dAlpha = cute::make_stride(bool(m_stride),cute::_0{}, l_stride); + fusion_args.dBeta = cute::make_stride(bool(m_stride),cute::_0{}, l_stride); + } + else if constexpr (IsPerColScaleEnabled) { + int32_t n_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int64_t l_stride = vector_scale_mode == VectorScale::ENABLED ? N : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + fusion_args.dAlpha = cute::make_stride(cute::_0{}, bool(n_stride), l_stride); + fusion_args.dBeta = cute::make_stride(cute::_0{}, bool(n_stride), l_stride); + } + else { + if constexpr (not IsFfma2Kernel) { + if (use_device_scalars == ScalarLoc::ON_DEVICE) { + if (L > 1) { + fusion_args.dAlpha = cute::make_stride(cute::_0{},cute::_0{}, int64_t(1)); + fusion_args.dBeta = cute::make_stride(cute::_0{},cute::_0{}, int64_t(1)); + } + } + } + } + + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_a = scale_A.at(coord_0); + fusion_args.scale_b = scale_B.at(coord_0); + fusion_args.scale_c = scale_C.at(coord_0); + fusion_args.scale_d = scale_D.at(coord_0); + fusion_args.scale_a_ptr = scale_A.device_data(); + fusion_args.scale_b_ptr = scale_B.device_data(); + fusion_args.scale_c_ptr = scale_C.device_data(); + fusion_args.scale_d_ptr = scale_D.device_data(); + } + + if constexpr (IsRowBiasEnabled or IsColBiasEnabled) { + fusion_args.bias_ptr = bias.device_data(); + } + + if constexpr 
(IsDeBiasEnabled) { + fusion_args.dbias_ptr = bias.device_data(); + } + + // example of how to set kernel activation arguments + // see ActivationFunctor::Arguments in activation.h for definition + // if Arguments doesn't exist then fusion_args.activation is empty + auto init_activation_args = [] (auto activation, auto& args) { + using Activation = cute::remove_cvref_t; + if constexpr (cute::is_same_v>) { + args.lower_bound = 0; // Treat Clamp as ReLU + args.upper_bound = cutlass::platform::identity_for_minimum(); + } + if constexpr (cute::is_same_v>) { + args.scale = ElementCompute(1); + } + }; + + if constexpr (not cute::is_same_v>) { + init_activation_args(ActivationFunctor{}, fusion_args.activation); + } + if constexpr (IsAbsMaxEnabledD) { + fusion_args.amax_D_ptr = abs_max_D.device_data(); + } + + if constexpr (IsAuxInEnabled) { + fusion_args.aux_ptr = tensor_Aux.device_data(); + fusion_args.dAux = stride_Aux; + } + + if constexpr (IsAuxOutEnabled) { + fusion_args.aux_ptr = tensor_Aux.device_data(); + fusion_args.dAux = stride_Aux; + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_aux = scale_Aux.at(coord_0); + fusion_args.scale_aux_ptr = scale_Aux.device_data(); + } + if constexpr (IsAbsMaxEnabledAux) { + fusion_args.amax_aux_ptr = abs_max_Aux.device_data(); + } + } + + + if constexpr (IsBlockScaleSupported) { + arguments.thread.block_scale_factor_ptr = tensor_SFD.device_data(); + arguments.thread.norm_constant_ptr = norm_constant.device_data(); + } + } + + return arguments; + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensor_C.host_data()), 
+ cute::make_layout(cute::make_shape(M, N, L), stride_c)); + auto D = cute::make_tensor(detail::make_iterator(reference_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()), + cute::make_layout(cute::make_shape(IsRowBiasEnabled ? M : N))); + auto Aux = cute::make_tensor(detail::make_iterator(IsAuxInEnabled ? tensor_Aux.host_data() : reference_Aux.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_Aux)); + auto Valpha = [&](){ + if constexpr (IsPerRowScaleEnabled) { + int m_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int l_stride = vector_scale_mode == VectorScale::ENABLED ? M : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + return cute::make_tensor(detail::make_iterator(alpha.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(m_stride, cute::_0{}, l_stride))); + } + else if constexpr (IsPerColScaleEnabled) { + int n_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int l_stride = vector_scale_mode == VectorScale::ENABLED ? N : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + return cute::make_tensor(detail::make_iterator(alpha.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, n_stride, l_stride))); + } + else { + return cute::make_tensor(detail::make_iterator(alpha.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, cute::_0{}, cute::_1{}))); + } + }(); + + auto Vbeta = [&]() { + if constexpr (IsPerRowScaleEnabled) { + int m_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int l_stride = vector_scale_mode == VectorScale::ENABLED ? M : (use_device_scalars == ScalarLoc::ON_DEVICE ? 
1 : 0); + return cute::make_tensor(detail::make_iterator(beta.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(m_stride, cute::_0{}, l_stride))); + } + else if constexpr (IsPerColScaleEnabled) { + int n_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int l_stride = vector_scale_mode == VectorScale::ENABLED ? N : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + return cute::make_tensor(detail::make_iterator(beta.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, n_stride, l_stride))); + } + else { + return cute::make_tensor(detail::make_iterator(beta.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, cute::_0{}, cute::_1{}))); + } + }(); + + auto SfD = [&](){ + if constexpr (IsBlockScaleSupported) { + auto tensor = make_tensor(detail::make_iterator(reference_SFD.host_data()), + Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL)); + return tensor; + } + else { + // Reference kernel has a logic to ignore scalefactor computation if we pass the tensor type same as output D tensor. 
+ return D; + } + }(); + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + decltype(Bias), + decltype(Aux), + decltype(Valpha), + decltype(Vbeta), + ActivationFunctor, + decltype(SfD), + Int, + cutlass::plus, + IsColBiasEnabled + , SfGenStrategy + > epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha.at(coord_0); + epilogue_params.beta = beta.at(coord_0); + + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_a = scale_A.at(coord_0); + epilogue_params.scale_b = scale_B.at(coord_0); + epilogue_params.scale_c = scale_C.at(coord_0); + epilogue_params.scale_d = scale_D.at(coord_0); + } + + if constexpr (IsRowBiasEnabled or IsColBiasEnabled or IsDeBiasEnabled) + { + epilogue_params.Bias = Bias; + } + + if constexpr (IsAbsMaxEnabledD) { + epilogue_params.abs_max_D = reference_abs_max_D.host_data(); + } + + if constexpr (IsAuxInEnabled) { + epilogue_params.Aux = Aux; + } + + if constexpr (IsAuxOutEnabled) { + epilogue_params.Aux = Aux; + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_aux = scale_Aux.at(coord_0); + } + if constexpr (IsAbsMaxEnabledAux) { + epilogue_params.abs_max_Aux = reference_abs_max_Aux.host_data(); + } + } + + if constexpr (IsPerRowScaleEnabled or IsPerColScaleEnabled) { + epilogue_params.Valpha = Valpha; + if (vector_scale_mode == VectorScale::ENABLED) { + epilogue_params.Vbeta = Vbeta; + } + } + else { + if (use_device_scalars == ScalarLoc::ON_DEVICE) { + epilogue_params.Valpha = Valpha; + epilogue_params.Vbeta = Vbeta; + } + } + + if constexpr (IsBlockScaleSupported) { + epilogue_params.SfD = SfD; + epilogue_params.st = norm_constant.at(coord_0); + } + return epilogue_params; + } +}; + +template < + typename Gemm, + template class ActivationFunctor_ = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false, + typename ElementA = typename 
Gemm::GemmKernel::ElementA, + typename ElementB = typename Gemm::GemmKernel::ElementB + , typename RuntimeDatatypeA = void* + , typename RuntimeDatatypeB = void* +> +struct TestbedImpl { + // Kernel data types + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + // All Collective MMA operands are defined by HostCollectiveMainloopType based on the schedule type + using HostCollectiveMainloopType = HostCollectiveMainloop; + + using CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, + HostCollectiveDefaultEpilogue, + HostCollectiveEpilogue>; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using LayoutTagA = typename HostCollectiveMainloopType::LayoutTagA; + using LayoutTagB = typename HostCollectiveMainloopType::LayoutTagB; + using LayoutTagC = typename CollectiveEpilogue::LayoutTagC; + using LayoutTagD = typename CollectiveEpilogue::LayoutTagD; + + + using InternalElementA = typename Gemm::GemmKernel::ElementA; + using InternalElementB = typename Gemm::GemmKernel::ElementB; + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + + uint32_t sm_count; + // Used to force multi-wave tests for persistent kernel schedules + constexpr static int MaxSmCount = 16; + static constexpr uint64_t kDefaultSeed = 4096; + static constexpr uint32_t 
mma_promotion_interval = 4; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + HostCollectiveMainloopType collective_mma_inputs; + CollectiveEpilogue collective_epilogue; + + // + // Methods + // + + TestbedImpl( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } + + TestbedImpl( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = 
cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, stride_factor_A_, stride_factor_B_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } + + /// Initializes data structures + bool initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::initialize(problem_size, alpha, beta)"); +#endif + collective_mma_inputs.initialize(problem_size); + collective_epilogue.initialize(problem_size, alpha_, beta_); + + return true; + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta) + { + auto [M, N, K, L] = problem_shape_MNKL; + + bool passed = collective_mma_inputs.compare_reference(problem_shape_MNKL); + passed &= collective_epilogue.compare_reference(problem_shape_MNKL, alpha, beta); + EXPECT_TRUE(passed); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + collective_mma_inputs.print_tensors(file); + collective_epilogue.print_tensors(file); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + ProblemShapeType problem_size, + ElementScalar alpha, + ElementScalar beta) + { + using namespace cute; + auto problem_shape_MNKL = 
cute::append<4>(problem_size, 1); + auto mainloop_params = collective_mma_inputs.to_host_args(problem_size); + auto epilogue_params = collective_epilogue.to_host_args(problem_size); + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + bool passed = compare_reference(problem_shape_MNKL, alpha, beta); + return passed; + } + + /// Determine if the CUDA device is sufficient to run the kernel + bool sufficient() { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = static_cast(Gemm::GemmKernel::SharedStorageSize); + + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + this->sm_count = properties.multiProcessorCount; + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } + + return true; + } + + bool profile( + ProblemShapeType problem_size, + int iterations, + Gemm& gemm_op, + typename Gemm::Arguments& arguments, + cutlass::device_memory::allocation& workspace) { + int M = cute::size<0>(problem_size); + int N = cute::size<1>(problem_size); + int K = cute::size<2>(problem_size); + int L = 1; + if constexpr(cute::rank(ProblemShapeType{}) == 4) { + L = cute::size<3>(problem_size); + } + + + cutlass::Status status; + // + // Run the GEMM + // + cudaError_t result; + + for (int iter = 0; iter < iterations; ++iter) { + status = gemm_op(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + return false; + } + } + + result = 
cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + return true; + } + + /// Executes one test + bool run( + ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + bool profiling = false, + detail::Iterations iterations = detail::Iterations{}, + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, + detail::Splits splits = detail::Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic + , RuntimeDatatypeA runtime_input_datatype_a = {} + , RuntimeDatatypeB runtime_input_datatype_b = {} + ) + { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run"); +#endif + + // Fail test if insufficient CUDA device + if (!sufficient()) { + CUTLASS_TRACE_HOST("TestbedImpl::run: Test failed due to insufficient CUDA device"); + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("TestbedImpl::run: sufficient() returned true"); + } +#endif + + try { + const bool initialized = this->initialize(problem_size, alpha, beta); + if (not initialized) { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize returned false"); + std::cerr << "Initialization failed \n"; + return false; + } + } + catch ([[maybe_unused]] std::exception const& e) { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize threw an exception: " << e.what()); + throw; + } + catch (...) 
{ + CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize threw an unknown exception"); + throw; + } + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize() returned true"); +#endif + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + this->sm_count = std::min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = this->sm_count; + } + else { + this->sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = this->sm_count; + } + + typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; + if constexpr (cute::is_same_v) { + scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order, decomposition_mode }; + } + else { + scheduler_args = { static_cast(max_swizzle), raster_order }; + } + typename HostCollectiveMainloopType::Arguments mainloop_args; + + mainloop_args = collective_mma_inputs.to_args(); + + + if constexpr (IsRuntimeDataType) { + mainloop_args.runtime_data_type_a = runtime_input_datatype_a; + mainloop_args.runtime_data_type_b = runtime_input_datatype_b; + } + + + arguments = + { + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + mainloop_args, + collective_epilogue.to_args(problem_size), + hw_info, + scheduler_args + }; + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Creating gemm_op"); +#endif + Gemm gemm_op; + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling Gemm::get_workspace_size"); +#endif + size_t workspace_size = Gemm::get_workspace_size(arguments); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Allocating workspace of size " << workspace_size); +#endif + cutlass::device_memory::allocation workspace(workspace_size); + +#if 
(CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling gemm_op.can_implement"); +#endif + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); + std::cerr << "This test is not supported: " << error_str << "\n"; + return true; + } + + // + // Run the GEMM + // + + if (profiling) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling profile"); +#endif + return profile(problem_size, static_cast(iterations), gemm_op, arguments, workspace); + } + else { + cudaError_t result; +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling gemm_op.initialize"); +#endif + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling gemm_op.run"); +#endif + status = gemm_op.run(); + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling cudaDeviceSynchronize"); +#endif + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaDeviceSynchronize reports non-success"); + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + 
CUTLASS_TRACE_HOST("TestbedImpl::run: Calling this->verify"); +#endif + bool passed = this->verify(problem_size, alpha, beta); + if (!passed) { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->verify FAILED"); + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); + + std::cout << "Error : Failed : with alpha: " << alpha << ", beta: " << beta + << "\n"; + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->verify passed"); + } +#endif + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Reached end"); +#endif + return passed; + } + } +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false, + typename ElementA = typename Gemm::GemmKernel::ElementA, + typename ElementB = typename Gemm::GemmKernel::ElementB + , typename RuntimeDatatypeA = void* + , typename RuntimeDatatypeB = void* +> +struct Testbed3x { + + using TestBedImpl = typename detail::TestbedImpl< + Gemm, + ActivationFunctor, + force_legacy_epilogue, + ElementA, + ElementB + , RuntimeDatatypeA + , RuntimeDatatypeB + >; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + + using ElementAccumulator = typename TestBedImpl::ElementAccumulator; + using ElementCompute = typename TestBedImpl::ElementCompute; + using ElementScalar = typename TestBedImpl::ElementScalar; + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename 
cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + // Detail Implementation + TestBedImpl impl_; + + // + // Methods + // + Testbed3x( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed) + : impl_(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_A_, init_B_, init_C_, init_scale_, init_bias_, seed_) {} + + /// Executes one test + bool run( + typename TestBedImpl::ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, + detail::Splits splits = detail::Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic, + bool profiling = false, + detail::Iterations iterations = detail::Iterations{} + , RuntimeDatatypeA runtime_input_datatype_a = {} + , RuntimeDatatypeB runtime_input_datatype_b = {} + ) + { + return impl_.run( + problem_size, alpha, beta, profiling, iterations, raster_order, max_swizzle, splits, decomposition_mode + , runtime_input_datatype_a, runtime_input_datatype_b + ); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestGemmPerf3x(int iterations = 20) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementAccumulator = typename 
Gemm::GemmKernel::ElementAccumulator; + using ElementScalar = ElementAccumulator; + bool passed = true; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + + std::vector problem_size_m = { 4608 }; + std::vector problem_size_n = { 4608 }; + std::vector problem_size_k = { 8192 }; + + Testbed3x testbed; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(0), + RasterOrderOptions{}, detail::MaxSwizzleSize(1), detail::Splits{1}, DecompositionMode{}, + true, // profiling + detail::Iterations{iterations}); + + if (!passed) { + return false; + } + } + } + } + + return true; +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +template < + typename Gemm, + typename RuntimeDataTypeA, + typename RuntimeDataTypeB, + bool force_legacy_epilogue = false> +bool TestRuntimeDataTypeSmall( + RuntimeDataTypeA runtime_input_datatype_a, + RuntimeDataTypeB runtime_input_datatype_b, + double alpha = 1.0, double beta = cute::is_same_v ? 
0.0 : 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, VectorScale vector_scale_mode = VectorScale::ENABLED, std::vector override_problem_size_k = {}) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; + using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; + + using InternalElementA = typename Gemm::GemmKernel::ElementA; + using InternalElementB = typename Gemm::GemmKernel::ElementB; + + CtaShape_MNK cta_shape; + static constexpr int SmCount = 16; + static constexpr int MultiplierOffsetM = 1; + static constexpr int MultiplierOffsetN = 2; + static constexpr int MultiplierOffsetK = 3; + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + + float waves[] = {0.5, 1.25, 2.5}; + int cluster_m = 1; + int cluster_n = 1; + + std::vector problem_size_k; + if (override_problem_size_k.empty()) { + problem_size_k = {256 + max_alignment * MultiplierOffsetK, 512 + max_alignment * MultiplierOffsetK}; + } + else { + problem_size_k = override_problem_size_k; + } + + if constexpr(DispatchPolicy::ArchTag::kMinComputeCapability >= 90) { + typename DispatchPolicy::ClusterShape cluster_shape; + cluster_m = cute::size<0>(cluster_shape); + cluster_n = cute::size<1>(cluster_shape); + } + + [[maybe_unused]] constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr 
(UsesStreamKScheduler) { + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + } + bool passed = true; + + for (float wave : waves) { + for (int k : problem_size_k) { + int grid_m, grid_n = 0; + int num_grid = int(wave * SmCount); + + if (cluster_m >= cluster_n) { + grid_m = cluster_m; + grid_n = num_grid / grid_m; + // Align grid_n to cluster_n + grid_n = std::max((grid_n + cluster_n - 1 ) / cluster_n * cluster_n, 1); + } + else { + grid_n = cluster_n; + grid_m = num_grid / grid_n; + // Align grid_m to cluster_m + grid_m = std::max((grid_m + cluster_m - 1 ) / cluster_m * cluster_m, 1); + } + + int m = grid_m * cute::size<0>(cta_shape) + MultiplierOffsetM * max_alignment; + int n = grid_n * cute::size<1>(cta_shape) + MultiplierOffsetN * max_alignment; + + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + for (DecompositionMode decomp_mode : decomposition_modes) { + std::vector problem_splits = {detail::Splits{1}}; + if (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK) { + problem_splits.push_back(detail::Splits{2}); + } + for (auto splits : problem_splits) { + + if constexpr (cute::is_same_v && + cute::is_same_v) { + // e2m1_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF4Format::E2M1 && + runtime_input_datatype_b == cute::UMMA::MXF4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + else 
{ + std::cout << "Unsupported configuration for runtime datatype MXFP4." << std::endl; + return false; + } + } + + else + if constexpr (cute::is_same_v && + cute::is_same_v) { + static_assert((cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v), + "Runtime datatype must be selected with an appropriate static umbrella data type."); + if constexpr (cute::is_same_v && + cute::is_same_v) { + // e4m3_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + // f6xf4 + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e3m2_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E3M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." 
<< std::endl; + return false; + } + } + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e2m1_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E2M1 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e4m3_e3m2 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E3M2) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." 
<< std::endl; + return false; + } + } + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e3m2_e2m3 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E3M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M3) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupported + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + else + if constexpr (cute::is_same_v && + cute::is_same_v) { + // e5m2_e5m2 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E5M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E5M2) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // e4m3_e5m2 + else if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E5M2){ + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // e5m2_e4m3 + else if (runtime_input_datatype_a == 
cute::UMMA::MXF8F6F4Format::E5M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E4M3){ + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // e4m3_e4m3 + else if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E4M3){ + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupported + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + // Unsupported + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported configuration for runtime datatype."); + } + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNK " << m << " " << n << " " << k << " FAILED.\n"; + return false; + } + } // splits + } // decomposition_mode + } // k + } // waves + + return passed; +} + +template +bool TestSmall(double alpha = 1.0, double beta = cute::is_same_v ? 
0.0 : 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED, + std::vector override_problem_size_k = {}) { + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; + using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; + CtaShape_MNK cta_shape; + Testbed3x testbed(check_relative_equality, use_device_scalars, vector_scale_mode); + static constexpr int SmCount = 16; + static constexpr int MultiplierOffsetM = 1; + static constexpr int MultiplierOffsetN = 2; + static constexpr int MultiplierOffsetK = 3; + int max_alignment_k = 0; + int max_alignment_m = 0; + int max_alignment_n = 0; + + if constexpr (apply_alignment_offset) { + max_alignment_k = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + max_alignment_n = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + max_alignment_m = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + } + // Alignment for SFD + if constexpr (detail::IsSfdEpi::value) { + using GmemLayoutTagScalefactor = typename Gemm::GemmKernel::CollectiveEpilogue::FusionCallbacks::Operation::GmemLayoutTagScalefactor; + constexpr int SFDVecSize = Gemm::GemmKernel::CollectiveEpilogue::FusionCallbacks::Operation::SFVecSize; + if constexpr (cute::is_same_v) { + max_alignment_n = std::lcm(max_alignment_n, SFDVecSize); + } + else { + max_alignment_m = std::lcm(max_alignment_m, SFDVecSize); + } + } + + float waves[] = {0.5, 1.25, 2.5}; + int cluster_m = 1; + int cluster_n = 1; + + std::vector problem_size_k; + if (override_problem_size_k.empty()) { + problem_size_k = {256 + max_alignment_k * MultiplierOffsetK, 512 + max_alignment_k * MultiplierOffsetK}; + } + else { + problem_size_k = override_problem_size_k; + } + + if 
constexpr(DispatchPolicy::ArchTag::kMinComputeCapability >= 90) { + typename DispatchPolicy::ClusterShape cluster_shape; + cluster_m = cute::size<0>(cluster_shape); + cluster_n = cute::size<1>(cluster_shape); + } + + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr (UsesStreamKScheduler) { + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + } + bool passed = true; + + std::vector raster_order_options = {RasterOrderOptions::Heuristic}; + for (float wave : waves) { + for (int k : problem_size_k) { + int grid_m, grid_n = 0; + int num_grid = int(wave * SmCount); + + if (cluster_m >= cluster_n) { + grid_m = cluster_m; + grid_n = num_grid / grid_m; + // Align grid_n to cluster_n + grid_n = std::max((grid_n + cluster_n - 1 ) / cluster_n * cluster_n, 1); + } + else { + grid_n = cluster_n; + grid_m = num_grid / grid_n; + // Align grid_m to cluster_m + grid_m = std::max((grid_m + cluster_m - 1 ) / cluster_m * cluster_m, 1); + } + + int m = grid_m * cute::size<0>(cta_shape) + MultiplierOffsetM * max_alignment_m; + int n = grid_n * cute::size<1>(cta_shape) + MultiplierOffsetN * max_alignment_n; + int l = test_batched_alpha_beta && wave == waves[0] && k == problem_size_k[0] ? 
2 : 1; // only test the smallest problem size + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, l}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + for (DecompositionMode decomp_mode : decomposition_modes) { + for (RasterOrderOptions raster_order : raster_order_options) { + std::vector problem_splits = {detail::Splits{1}}; + if constexpr (UsesStreamKScheduler) { + if (decomp_mode == DecompositionMode::SplitK) { + problem_splits.push_back(detail::Splits{2}); + problem_splits.push_back(detail::Splits{4}); + } + } + for (auto splits : problem_splits) { + try { + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + raster_order, // raster_order + detail::MaxSwizzleSize(0), + splits, + decomp_mode + ); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "TestSmall: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k << ", l: " << l + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: " << detail::raster_order_to_string(raster_order) + << ", max_swizzle_size: 1" + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} threw an exception: " << e.what(); + throw; + } + catch (...) 
{ + EXPECT_TRUE(false) << "TestSmall: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k << ", l: " << l + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: " << detail::raster_order_to_string(raster_order) + << ", max_swizzle_size: 1" + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} threw an exception (unknown)"; + throw; + } + EXPECT_TRUE(passed) << "TestSmall: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k << ", l: " << l + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: " << detail::raster_order_to_string(raster_order) + << ", max_swizzle_size: 1" + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} failed"; + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNKL " << m << " " << n << " " << k << " " << l << " FAILED.\n"; + return false; + } + } // splits + } // raster_order + } // decomposition_mode + } // k + } // waves + + return passed; +} + +template +bool TestSmallFusion(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED, + std::vector override_problem_size_k = {}) { + return TestSmall(alpha, + beta, + check_relative_equality, + use_device_scalars, + vector_scale_mode, + override_problem_size_k); +} + + + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity +> +bool TestAll(double alpha = 1.0, double beta = cute::is_same_v ? 
0.0 : 1.0, CheckEquality check_relative_equality = CheckEquality::RELATIVE) { + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + Testbed3x testbed(check_relative_equality, ScalarLoc::ON_HOST, VectorScale::DISABLED); + + int max_alignment_m = std::max({Gemm::kAlignmentA, Gemm::kAlignmentC, Gemm::kAlignmentD}); + int max_alignment_n = std::max({Gemm::kAlignmentB, Gemm::kAlignmentC, Gemm::kAlignmentD}); + if constexpr (std::is_base_of_v) { + max_alignment_m = std::max(max_alignment_m, Gemm::EpilogueOutputOp::AlignmentAux); + max_alignment_n = std::max(max_alignment_n, Gemm::EpilogueOutputOp::AlignmentAux); + } + std::vector problem_size_m = {max_alignment_m, 512 - 3 * max_alignment_m}; + std::vector problem_size_n = {max_alignment_n, 512 - 2 * max_alignment_n}; + + if constexpr (cute::is_same_v) { + problem_size_m.push_back(768); + problem_size_n.push_back(768); + } + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + int max_alignment_k = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_k = {max_alignment_k, TileShapeK * (Stages + 1) - max_alignment_k}; + + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + std::vector problem_splits = {detail::Splits{1}}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr (UsesStreamKScheduler) { + problem_splits.push_back(detail::Splits{2}); + problem_splits.push_back(detail::Splits{3}); + + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + + // Use larger K sizes for stream-K tests + static constexpr int 
min_tiles_per_sk_unit = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::min_iters_per_sk_unit_; + problem_size_k = {TileShapeK * min_tiles_per_sk_unit, TileShapeK * 3 * min_tiles_per_sk_unit - max_alignment_k}; + } + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + std::vector raster_orders = {RasterOrderOptions::AlongM, RasterOrderOptions::AlongN}; + std::vector max_swizzle_sizes{detail::MaxSwizzleSize{1}, detail::MaxSwizzleSize{4}}; + + bool passed = true; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (auto raster_order : raster_orders) { + for (auto max_swizzle_size : max_swizzle_sizes) { + for (DecompositionMode decomp_mode : decomposition_modes) { + + std::vector problem_splits = {detail::Splits{1}}; + if (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK) { + auto max_splits = (k + TileShapeK - 1) / TileShapeK; + if (max_splits > 2) { + problem_splits.push_back(detail::Splits{2}); + } + if (max_splits > 3) { + problem_splits.push_back(detail::Splits{3}); + } + + problem_splits.push_back(detail::Splits{max_splits}); + + // Test the case in which we ask for more splits than there are K tiles in the GEMM. In this + // case, split-K will fall back to a splitting factor of `max_splits`. 
+ problem_splits.push_back(detail::Splits{max_splits + 1}); + } + for (auto splits : problem_splits) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + try { + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + raster_order, + max_swizzle_size, + splits, + decomp_mode + ); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "TestAll: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: ???" + << ", max_swizzle_size: " << static_cast(max_swizzle_size) + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} threw an exception: " << e.what(); + throw; + } + catch (...) { + EXPECT_TRUE(false) << "TestAll: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: ???" + << ", max_swizzle_size: " << static_cast(max_swizzle_size) + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} threw an exception (unknown)"; + throw; + } + + EXPECT_TRUE(passed) << "TestAll: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: ???" 
+ << ", max_swizzle_size: " << static_cast(max_swizzle_size) + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} failed"; + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNK " << m << " " << n << " " << k << " FAILED.\n"; + return false; + } + } // splits + } // decomposition_mode + } // max_swizzle_size + } // raster_order + } // k + } // n + } // m + + // if we do support batched GEMM, just run one test on it to save on test time + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + auto problem_size = ProblemShapeType{256 + max_alignment_m, 256 + max_alignment_n, 160 + max_alignment_k, /* l */ 3}; + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + + return passed; +} + +template +bool TestAllBiasElementwise(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, CheckEquality check_relative_equality = CheckEquality::EXACT) { + return TestAll(alpha, beta, check_relative_equality); +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f18a7b39cbfe7dfb8d3251b2750e49261522de8a --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp @@ -0,0 +1,1742 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Testbed and host reference for EVT unittest +*/ + + +#pragma once +#include "gemm_testbed_3x.hpp" + +namespace test { +namespace gemm { +namespace device { + +/// Host-side tapply, tapply in cute is HOST_DEVICE +template +constexpr auto +tapply(T&& t, F&& f, G&& g, cute::seq) +{ + return g(f(std::get(static_cast(t)))...); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT: Base class for EVT Node + +template < class ElementCompute_ > +class HostEVTNodeBase { +public: + using ElementCompute = ElementCompute_; + +private: + bool check_relative_equality_; + // Factors used for calculating relative equality. These default + // values are borrowed from those used by default in the CUTLASS + // profiler for performing relative equality checks. + float epsilon_ = 0.05f; + float nonzero_floor_ = 1.0f / 256.0f; + +public: + HostEVTNodeBase(){} + HostEVTNodeBase(bool check_relative_equality): + check_relative_equality_(check_relative_equality) { } + + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + if (check_relative_equality_) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, Element(epsilon_), Element(nonzero_floor_) + ); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + void* get_tensor_C_ptr() { + return nullptr; + } + + void* get_tensor_D_ptr() { + return nullptr; + } + + bool compare_reference(std::stringstream& error_ss) { + return true; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Accumulator + +template< class ElementCompute = float > +class HostAccumulator: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + + struct Arguments { }; + +public: + HostAccumulator(){} + template + HostAccumulator(ProblemShapeType problem_size, bool 
check_relative_equality = false, int64_t seed = 2024) + :Base(check_relative_equality) {} + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + cutlass::NumericConverter accumulator_converter; + return accumulator_converter(acc); + } + + Arguments get_arguments() { + return Arguments{}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Scalar Broadcast + +template < + int Value, + int BroadcastCount = 1, + class StrideMNL = cute::Stride, + template class ReductionFn = cutlass::multiplies, + class ElementCompute = float +> +class HostScalarBroadcast : public HostEVTNodeBase { +public: + + using Base = HostEVTNodeBase; + struct Arguments { + ElementCompute scalar[BroadcastCount] = {0}; + ElementCompute const* scalar_ptrs[BroadcastCount] = { nullptr }; + StrideMNL dScalar[BroadcastCount] = {}; + }; +private: + ElementCompute scalar_{}; + StrideMNL dScalar{}; + ElementCompute scalar_reduced_{}; +public: + HostScalarBroadcast(){} + + template + HostScalarBroadcast(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + : Base(check_relative_equality), scalar_(ElementCompute(Value)) { + scalar_ = ElementCompute(Value); + scalar_reduced_ = scalar_; + for (int i = 1; i < BroadcastCount; ++i) { + scalar_reduced_ = ReductionFn{}(scalar_reduced_, ElementCompute(Value)); + } + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + return scalar_reduced_; + } + + bool compare_reference(std::stringstream& error_ss) { + error_ss << "Scalar: " << float(scalar_) << "\n\n"; + return true; + } + + Arguments get_arguments() { + if constexpr (BroadcastCount == 1) + return Arguments{{scalar_}, {nullptr}, {dScalar}}; + else if constexpr (BroadcastCount == 2) + return 
Arguments{{scalar_, scalar_}, {nullptr, nullptr}, {dScalar, dScalar}}; + else if constexpr (BroadcastCount == 3) + return Arguments{{scalar_, scalar_, scalar_}, {nullptr, nullptr, nullptr}, {dScalar, dScalar, dScalar}}; + else + return Arguments{{scalar_}, {nullptr}, {dScalar}}; + } + + auto get_flatten_arguments() { + if constexpr (BroadcastCount == 1) { + return cute::make_tuple(scalar_, nullptr); + } + else if constexpr (BroadcastCount == 2) { + return cute::make_tuple(scalar_, scalar_, nullptr, nullptr); + } + else if constexpr (BroadcastCount == 3) { + return cute::make_tuple(scalar_, scalar_, scalar_, nullptr, nullptr, nullptr); + } + else { + return cute::make_tuple(scalar_, nullptr); + } + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Row Broadcast +template < + typename ElementBias_, + typename StrideMNL = cute::Stride, + typename ElementCompute = float +> +class HostRowBroadcast: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ElementBias = ElementBias_; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + struct Arguments { + ElementBias const* ptr_row = nullptr; + ElementBias null_default = ElementBias(0); + StrideMNL dRow = {}; + }; +private: + cutlass::NumericConverter bias_converter_; + cutlass::HostTensor bias_; + int N_; +public: + HostRowBroadcast(){} + template + HostRowBroadcast(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + : Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + N_ = cute::get<1>(problem_shape_MNKL); + bias_.resize(cutlass::Coord<1>(N_)); + + EXPECT_TRUE( + detail::initialize_tensor( + bias_.host_view(), cutlass::Distribution::Uniform, + seed + ) + ); + bias_.sync_device(); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + auto TensorBias = 
cute::make_tensor(bias_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}, N_))); + + return bias_converter_(TensorBias(1, n + n_b)); + } + + bool compare_reference(std::stringstream& error_ss) { + error_ss + << "PerColumnBias = \n" << bias_.host_view() << "\n\n"; + return true; + } + + Arguments get_arguments() { + return {bias_.device_data()}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(bias_.device_data(), ElementBias(0), StrideMNL{}); + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Column Broadcast +template < + typename ElementBias_, + typename StrideMNL = cute::Stride, + typename ElementCompute = float +> +class HostColBroadcast: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ElementBias = ElementBias_; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + struct Arguments { + ElementBias const* ptr_row = nullptr; + ElementBias null_default = ElementBias(0); + StrideMNL dRow = {}; + }; +private: + cutlass::NumericConverter bias_converter_; + cutlass::HostTensor bias_; + int M_; +public: + HostColBroadcast(){} + template + HostColBroadcast(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + : Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + M_ = cute::get<0>(problem_shape_MNKL); + bias_.resize(cutlass::Coord<1>(M_)); + + EXPECT_TRUE( + detail::initialize_tensor( + bias_.host_view(), cutlass::Distribution::Uniform, + seed + ) + ); + bias_.sync_device(); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + auto TensorBias = cute::make_tensor(bias_.host_data(), + cute::make_layout(cute::make_shape(M_, cute::_1{}))); + + return bias_converter_(TensorBias(m + m_b, 1)); + } + + bool compare_reference(std::stringstream& error_ss) { + error_ss + << "PerRowBias = \n" 
<< bias_.host_view() << "\n\n"; + return true; + } + + Arguments get_arguments() { + return {bias_.device_data()}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(bias_.device_data(), ElementBias(0), StrideMNL{}); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Aux Load + +template < + typename ElementAuxLoad_, + typename LayoutTagAux_, + bool isC = false, + typename ElementCompute = float +> +class HostAuxLoad: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ElementAuxLoad = ElementAuxLoad_; + using LayoutTagAux = LayoutTagAux_; + + using StrideAux = cutlass::gemm::TagToStrideC_t; + struct Arguments_Aux { + ElementAuxLoad const *ptr_aux = nullptr; + ElementAuxLoad null_default = ElementAuxLoad(0); + StrideAux dAux = {}; + }; + + struct Arguments_C {}; + + using Arguments = cute::conditional_t; + +private: + cutlass::NumericConverter aux_load_converter_; + cutlass::HostTensor tensor_aux_load_; + + int M_, N_, L_; + + StrideAux stride_aux_; +public: + HostAuxLoad(){} + template + HostAuxLoad(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + : Base(check_relative_equality) { + auto problem_shape_NMKL = cute::append<4>(problem_size, 1); + auto [M_, N_, K, L_] = problem_shape_NMKL; + auto aux_coord = cutlass::make_Coord(M_ * L_, N_); + tensor_aux_load_.resize( + aux_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory( + aux_coord, typename LayoutTagAux::Stride() + ) + ); + EXPECT_TRUE( + detail::initialize_tensor( + tensor_aux_load_.host_view(), + cutlass::Distribution::Uniform, + seed + ) + ); + tensor_aux_load_.sync_device(); + stride_aux_ = cutlass::make_cute_packed_stride(StrideAux{}, cute::make_shape(M_, N_, L_)); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + + auto TensorAuxLoad = 
cute::make_tensor(tensor_aux_load_.host_data(), + cute::make_layout(cute::make_shape(M_, N_, L_), stride_aux_)); + return aux_load_converter_(TensorAuxLoad(m + m_b, n + n_b, l)); + } + + bool compare_reference(std::stringstream& error_ss) { + if constexpr (!isC) { + error_ss + << "AuxLoad = \n" << tensor_aux_load_.host_view()<< "\n\n"; + } + return true; + } + + void* get_tensor_C_ptr() { + if constexpr (isC) { + return static_cast(tensor_aux_load_.device_data()); + } + else { + return nullptr; + } + } + + Arguments get_arguments() { + if constexpr (isC) + return {}; + else + return {tensor_aux_load_.device_data(), ElementAuxLoad(0), stride_aux_}; + } + + auto get_flatten_arguments() { + if constexpr (isC) + return cute::make_tuple(); + else + return cute::make_tuple(tensor_aux_load_.device_data(), ElementAuxLoad(0), stride_aux_); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Compute + +template +T* findNonNullPtr(T* first_ptr) { + return first_ptr; +} + +template +T* findNonNullPtr(T* first_ptr, Args... args) { + if (first_ptr) { + return first_ptr; + } + return findNonNullPtr(args...); +} + +template < + template class ComputeOp_, + typename ElementCompute = float +> +class HostCompute: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ComputeOp = ComputeOp_; + + struct Arguments { + struct OpArgs {} op; + }; +private: + ComputeOp op_; +public: + HostCompute(){} + template + HostCompute(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, Args... 
frg_inputs) { + return op_(frg_inputs...); + } + + Arguments get_arguments(){ + return {}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Aux Store + +template < + class ElementAuxStore_, + typename LayoutTagAux_, + bool isD = false, + bool isRelu = false, + typename ElementCompute = float +> +class HostAuxStore: public HostEVTNodeBase { +public: + using ElementAuxStore = ElementAuxStore_; + using LayoutTagAux = LayoutTagAux_; + + using Base = HostEVTNodeBase; + + using StrideAux = cutlass::gemm::TagToStrideC_t; + struct Arguments_Aux { + struct OpArgs { + ElementAuxStore* ptr_aux = nullptr; + StrideAux dAux = {}; + } op; + }; + + struct Arguments_D {}; + + using Arguments = cute::conditional_t; + + +private: + cutlass::NumericConverter destination_converter_; + cutlass::HostTensor tensor_aux_store_; + cutlass::HostTensor reference_aux_store_; + int M_, N_, L_; + StrideAux stride_aux_; +public: + HostAuxStore(){} + template + HostAuxStore(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M_, N_, K, L_] = problem_shape_MNKL; + auto aux_coord = cutlass::make_Coord(M_ * L_, N_); + tensor_aux_store_.resize( + aux_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory( + aux_coord, typename LayoutTagAux::Stride() + ) + ); + + reference_aux_store_.resize( + aux_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory( + aux_coord, typename LayoutTagAux::Stride() + ) + ); + tensor_aux_store_.sync_device(); + stride_aux_ = cutlass::make_cute_packed_stride(StrideAux{}, cute::make_shape(M_, N_, L_)); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + + auto TensorAuxStore = 
cute::make_tensor(detail::make_iterator(static_cast(reference_aux_store_.host_data())), + cute::make_layout(cute::make_shape(M_, N_, L_), stride_aux_)); + if constexpr (isRelu) + TensorAuxStore(m + m_b, n + n_b, l) = destination_converter_(child_0_result >= 0); + else + TensorAuxStore(m + m_b, n + n_b, l) = destination_converter_(child_0_result); + return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + // Verify the store node + tensor_aux_store_.sync_host(); + + bool equal = this->equality_check(reference_aux_store_.host_view(), tensor_aux_store_.host_view()); + if (!equal) { + error_ss + << "\n\nReference =\n" << reference_aux_store_.host_view() + << "\n\nComputed =\n" << tensor_aux_store_.host_view() << "\n\n"; + } + return equal; + } + + void* get_tensor_D_ptr() { + if constexpr (isD) + return static_cast(tensor_aux_store_.device_data()); + else + return nullptr; + } + + Arguments get_arguments() { + if constexpr (isD) { + return {}; + } + else { + return {tensor_aux_store_.device_data(), stride_aux_}; + } + } + + auto get_flatten_arguments() { + if constexpr (isD) { + return cute::make_tuple(); + } + else { + return cute::make_tuple(tensor_aux_store_.device_data(), stride_aux_); + } + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Row Reduce + +template < + template class ReduceFn, + typename ElementReduce, + bool FinalReduction = true, // Should match the FinalReduction in Device type + typename CtaTileShapeMNK = cute::Shape, + typename ElementCompute = float +> +class HostRowReduce: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementDst = cute::conditional_t; + + static constexpr int TileM = cute::get<0>(CtaTileShapeMNK{}); + static constexpr int TileN = cute::get<1>(CtaTileShapeMNK{}); + + struct Arguments { + struct OpArgs { + ElementReduce* ptr_row = nullptr; 
+ ElementCompute reduce_identity = 0; + cute::Stride dRow = {}; + } op; + }; + +private: + cutlass::NumericConverter destination_converter_; + cutlass::HostTensor tensor_row_reduce_; + cutlass::HostTensor reduce_buffer_; + cutlass::HostTensor reference_row_reduce_; + int N_; + ReduceFn reduce_fn_; + + int extent_m_; + int extent_n_; + int extent_l_; +public: + HostRowReduce(){} + template + HostRowReduce(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + N_ = cute::get<1>(problem_shape_MNKL); + if constexpr (FinalReduction) { + tensor_row_reduce_.resize(cutlass::Coord<1>(N_)); + reference_row_reduce_.resize(cutlass::Coord<1>(N_)); + reduce_buffer_.resize(cutlass::Coord<1>(N_)); + } + else { + auto NumTile = cute::ceil_div(cute::select<0,1,3>(problem_shape_MNKL), cute::take<0,2>(CtaTileShapeMNK{})); + extent_m_ = cute::get<0>(NumTile); + extent_n_ = cute::get<1>(NumTile) * TileN; + extent_l_ = cute::get<2>(NumTile); + auto shape = cutlass::make_Coord(extent_m_ * extent_n_ * extent_l_); + tensor_row_reduce_.resize(shape); + reference_row_reduce_.resize(shape); + reduce_buffer_.resize(shape); + } + + cutlass::reference::host::TensorFill(reduce_buffer_.host_view()); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + if constexpr (FinalReduction) { + auto TensorRowReduce = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}, N_))); + TensorRowReduce(1, n + n_b) = reduce_fn_(TensorRowReduce(1, n + n_b), child_0_result); + } + else { + auto TensorRowReduce = cute::make_tensor( + reduce_buffer_.host_data(), + cute::make_layout( + cute::make_shape(extent_m_, extent_n_, extent_l_), + cute::make_stride(extent_n_, 1, extent_m_ * extent_l_) + ) + ); + TensorRowReduce((m+m_b)/TileM, n+n_b, l) = 
reduce_fn_(TensorRowReduce((m+m_b)/TileM, n+n_b, l), child_0_result); + } + + return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + // Verify the store node + tensor_row_reduce_.sync_host(); + + auto TensorRowReduce = cute::make_tensor(reference_row_reduce_.host_data(), + cute::make_layout(cute::make_shape(reference_row_reduce_.size()))); + + auto TensorReduceBuffer = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(reduce_buffer_.size()))); + + // Filling the reference tensor with the reduce buffer + for (uint64_t n = 0; n < size(TensorRowReduce); n ++) { + TensorRowReduce(n) = destination_converter_(TensorReduceBuffer(n)); + } + + bool equal = this->equality_check(reference_row_reduce_.host_view(), tensor_row_reduce_.host_view()); + if (!equal) { + error_ss + << "\n\nRow Reduce Reference =\n" << reference_row_reduce_.host_view() + << "\n\nRow Reduce Computed =\n" << tensor_row_reduce_.host_view() << "\n\n"; + } + return equal; + } + + Arguments get_arguments() { + return {tensor_row_reduce_.device_data()}; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Column Reduce + +template < + template class ReduceFn, + typename ElementReduce, + bool FinalReduction = true, // Should match the FinalReduction in Device type + typename CtaTileShapeMNK = cute::Shape, + typename ElementCompute = float +> +class HostColumnReduce: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementDst = cute::conditional_t; + + static constexpr int TileM = cute::get<0>(CtaTileShapeMNK{}); + static constexpr int TileN = cute::get<1>(CtaTileShapeMNK{}); + + struct Arguments { + struct OpArgs { + ElementReduce* ptr_col = nullptr; + ElementCompute reduce_identity = 0; + cute::Stride dRow = {}; + } op; + }; + +private: + cutlass::NumericConverter destination_converter_; 
+ cutlass::HostTensor tensor_column_reduce_; + cutlass::HostTensor reduce_buffer_; + cutlass::HostTensor reference_column_reduce_; + int M_; + ReduceFn reduce_fn_; + + int extent_m_; + int extent_n_; + int extent_l_; +public: + HostColumnReduce(){} + template + HostColumnReduce(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + M_ = cute::get<0>(problem_shape_MNKL); + + if constexpr (FinalReduction) { + tensor_column_reduce_.resize(cutlass::Coord<1>(M_)); + reference_column_reduce_.resize(cutlass::Coord<1>(M_)); + reduce_buffer_.resize(cutlass::Coord<1>(M_)); + } + else { + auto NumTile = cute::ceil_div(cute::select<0,1,3>(problem_shape_MNKL), cute::take<0,2>(CtaTileShapeMNK{})); + extent_m_ = cute::get<0>(NumTile) * TileM; + extent_n_ = cute::get<1>(NumTile); + extent_l_ = cute::get<2>(NumTile); + auto shape = cutlass::make_Coord(extent_m_ * extent_n_ * extent_l_); + tensor_column_reduce_.resize(shape); + reference_column_reduce_.resize(shape); + reduce_buffer_.resize(shape); + } + + cutlass::reference::host::TensorFill(reduce_buffer_.host_view()); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + auto TensorColReduce = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(M_, cute::_1{}))); + if constexpr (FinalReduction) { + TensorColReduce(m + m_b, 1) = reduce_fn_(TensorColReduce(m + m_b, 1), child_0_result); + } + else { + auto shape = reduce_buffer_.extent(); + auto TensorColReduce = cute::make_tensor( + reduce_buffer_.host_data(), + cute::make_layout( + cute::make_shape(extent_m_, extent_n_, extent_l_), + cute::make_stride(1, extent_m_, extent_m_ * extent_l_) + ) + ); + TensorColReduce(m+m_b, (n+n_b)/TileN, l) = reduce_fn_(TensorColReduce(m+m_b, (n+n_b)/TileN, l), child_0_result); + } + 
return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + // Verify the store node + tensor_column_reduce_.sync_host(); + + auto TensorColReduce = cute::make_tensor(reference_column_reduce_.host_data(), + cute::make_layout(cute::make_shape(reference_column_reduce_.size()))); + + auto TensorReduceBuffer = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(reduce_buffer_.size()))); + + // Filling the reference tensor with the reduce buffer + for (uint64_t m = 0; m < size(TensorColReduce); m ++) { + TensorColReduce(m) = destination_converter_(TensorReduceBuffer(m)); + } + + bool equal = this->equality_check(reference_column_reduce_.host_view(), tensor_column_reduce_.host_view()); + if (!equal) { + error_ss + << "\n\nColumn Reduce Reference =\n" << reference_column_reduce_.host_view() + << "\n\nColumn Reduce Computed =\n" << tensor_column_reduce_.host_view() << "\n\n"; + } + return equal; + } + + Arguments get_arguments() { + return {tensor_column_reduce_.device_data()}; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Scalar Reduce + +template < + template class ReduceFn, + typename ElementReduce, + typename ElementCompute = float, + bool enabled = true +> +class HostScalarReduce: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + struct Arguments { + struct OpArgs { + ElementReduce* ptr_scalar = nullptr; + ElementCompute reduce_identity = 0; + cute::Stride dScalar = {}; + } op; + }; + +private: + cutlass::NumericConverter destination_converter_; + cutlass::HostTensor tensor_scalar_reduce_; + cutlass::HostTensor reduce_buffer_; + cutlass::HostTensor reference_scalar_reduce_; + ReduceFn reduce_fn_; +public: + HostScalarReduce(){} + template + HostScalarReduce(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + 
Base(check_relative_equality) { + tensor_scalar_reduce_.resize(cutlass::Coord<1>(1)); + reference_scalar_reduce_.resize(cutlass::Coord<1>(1)); + reduce_buffer_.resize(cutlass::Coord<1>(1)); + + tensor_scalar_reduce_.sync_device(); + cutlass::reference::host::TensorFill(reduce_buffer_.host_view()); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + auto TensorRowReduce = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}))); + TensorRowReduce(0) = reduce_fn_(TensorRowReduce(0), child_0_result); + return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + if constexpr (enabled) { + // Verify the store node + tensor_scalar_reduce_.sync_host(); + + auto TensorRowReduce = cute::make_tensor(reference_scalar_reduce_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}))); + + auto TensorReduceBuffer = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}))); + + // Filling the reference tensor with the reduce buffer + TensorRowReduce(0) = destination_converter_(TensorReduceBuffer(0)); + + bool equal = this->equality_check(reference_scalar_reduce_.host_view(), tensor_scalar_reduce_.host_view()); + if (!equal) { + error_ss + << "\n\nScalar Reduce Reference =\n" << reference_scalar_reduce_.host_view() + << "\n\nScalar Reduce Computed =\n" << tensor_scalar_reduce_.host_view() << "\n\n"; + } + return equal; + } + else { + return true; + } + + } + + Arguments get_arguments() { + return {tensor_scalar_reduce_.device_data()}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(tensor_scalar_reduce_.device_data()); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Host EVT wrapper + +/// The ArgumentPack is used to model the alignment when num ops <= 4 +template +struct ArgumentPack; + 
+template +struct ArgumentPack { + T arg; + ArgumentPack(T first): + arg(first) {} +}; + +template +struct ArgumentPack { + First arg; + ArgumentPack rest_args; + + ArgumentPack(First first, Rest... rest) : + arg(first), rest_args(rest...) {} +}; + + +/// Base class for Host Visitor +template +struct HostVisitorBase: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + + using Arguments_struct = ArgumentPack; + using Arguments_tuple = cute::tuple; + + constexpr static int Rm1 = sizeof...(Ops); + constexpr static bool cond = Rm1 > 4; + using Arguments = cute::conditional_t; + + std::tuple ops; + + HostVisitorBase(){} + template + HostVisitorBase(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(check_relative_equality), + ops(test::gemm::device::tapply(std::tuple{}, + [&] (auto&& op) { + using Op = cute::remove_cvref_t; + return Op(problem_size, check_relative_equality, seed); + }, + [] (auto&&... _ops) { + return std::make_tuple(_ops...); + }, + cute::make_seq{} + )){ } + + bool compare_reference(std::stringstream& error_ss) { + return cute::detail::tapply(ops, + [&](auto& op) { + return op.compare_reference(error_ss); + }, + [&] (auto&&... inputs) { + return arrayAnd(inputs...); + }, + cute::make_seq{} + ); + } + + void* get_tensor_C_ptr() { + return cute::detail::tapply(ops, + [&](auto& op) { + return op.get_tensor_C_ptr(); + }, + [&] (auto&&... inputs) { + return findNonNullPtr(inputs...); + }, + cute::make_seq{} + ); + } + + void* get_tensor_D_ptr() { + return cute::detail::tapply(ops, + [&](auto& op) { + return op.get_tensor_D_ptr(); + }, + [&] (auto&&... inputs) { + return findNonNullPtr(inputs...); + }, + cute::make_seq{} + ); + } + + Arguments get_arguments() { + return test::gemm::device::tapply(ops, + [&](auto& op) { + return op.get_arguments(); + }, + [&] (auto&&... 
args) { + if constexpr (Rm1 > 4) { + return cute::make_tuple(args...); + } + else { + return Arguments(args...); + } + }, + cute::make_seq{} + ); + } + + auto get_flatten_arguments() { + return test::gemm::device::tapply(ops, + [&](auto& op) { + return op.get_flatten_arguments(); + }, + [&] (auto&&... args) { + return flatten(cute::make_tuple(args...)); + }, + cute::make_seq{} + ); + } + + bool arrayAnd(bool passed) { + return passed; + } + + template + bool arrayAnd(bool first_passed, Args... passed) { + if (first_passed) { + return arrayAnd(passed...); + } + return first_passed; + } + +}; + + +/// Tree-struct visitor +template +struct HostTreeVisitor: public HostVisitorBase { +public: + using ElementCompute = typename NodeOp::Base::ElementCompute; + using Base = HostVisitorBase; + using Arguments = typename Base::Arguments; + + constexpr static int Rm1 = sizeof...(ChildOps); + + HostTreeVisitor(){} + template + HostTreeVisitor(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(problem_size, check_relative_equality, seed){ } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + return cute::detail::tapply(this->ops, + [&] (auto& op) { + return op.visit(m, n, l, m_b, n_b, acc); + }, + [&] (auto&&... 
frg_inputs) { + return std::get(this->ops).visit(m, n, l, m_b, n_b, acc, frg_inputs...); + }, + cute::make_seq{} + ); + } +}; + + +/// General Graph visitor +template +struct HostTopoVisitor: public HostVisitorBase { +public: + using Base = HostVisitorBase; + constexpr static int Rm1 = Base::Rm1; + using Arguments = typename Base::Arguments; + +private: + ElementCompute frg_outputs_[Rm1]; +public: + HostTopoVisitor(){} + template + HostTopoVisitor(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(problem_size, check_relative_equality, seed) { } + + template + ElementCompute visit_( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + frg_outputs_[I] = cute::transform_apply(cute::get(EdgeTuple{}), + [&] (auto&& _E) { + constexpr int e = cute::remove_cvref_t::value; + return frg_outputs_[e]; + }, + [&] (auto const&... frg_inputs) { + ElementCompute res = std::get(this->ops).visit(m, n, l, m_b, n_b, acc, frg_inputs...); + return res; + } + ); + + if constexpr (I < Rm1 - 1) { + return visit_(m, n, l, m_b, n_b, acc); + } + else { + return frg_outputs_[I]; + } + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + return visit_(m, n, l, m_b, n_b, acc); + } + +}; + + +/// SplitTree visitor +template +struct HostSplitTreeVisitor: public HostVisitorBase { +public: + using Base = HostVisitorBase; + using Arguments = typename Base::Arguments; + + constexpr static int Rm2 = sizeof...(AuxOutTrees); + +private: + ElementCompute frg_input_; +public: + HostSplitTreeVisitor(){} + template + HostSplitTreeVisitor(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(problem_size, check_relative_equality, seed) { } + + template + void visitAux( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator frag) { + std::get(this->ops).visit(m, n, l, m_b, n_b, frag); + + if constexpr 
(I < Rm2 - 1) { + return visitAux(m, n, l, m_b, n_b, frag); + } + else { + return; + } + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + /// Compute the input tree + frg_input_ = std::get<0>(this->ops).visit(m, n, l, m_b, n_b, acc); + + /// Compute the aux out tree + visitAux(m, n, l, m_b, n_b, frg_input_); + /// Visit the output tree + return std::get(this->ops).visit(m, n, l, m_b, n_b, frg_input_); + } +}; + +/// Universal testbed for EVT w/o smem +template +class Testbed3xEVTnoSmem { +public: + // The EVT Module to test + using EVTModule = EVT; //typename EVT::EVTModule; + + using TestBedImpl = typename detail::TestbedImpl; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + using ElementAccumulator = typename Kernel::ElementAccumulator; + using ElementC = typename Kernel::ElementC; + using ElementD = typename Kernel::ElementD; + + using ProblemShapeType = typename Kernel::ProblemShape; + + using LayoutTagA = typename TestBedImpl::LayoutTagA; + using LayoutTagB = typename TestBedImpl::LayoutTagB; + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + // + // Methods + // + Testbed3xEVTnoSmem( + bool check_relative_equality_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed ) : + impl_((check_relative_equality_ ? 
CheckEquality::RELATIVE : CheckEquality::EXACT), ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(check_relative_equality_) { } + + Testbed3xEVTnoSmem( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed ) : + impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(false) { } + + /// Initializes data structures + void initialize(ProblemShapeType problem_size) { + // + // Allocate the GEMM workspace for A/B tensor + // + impl_.initialize(problem_size); + } + // Detail Implementation + TestBedImpl impl_; + + // Whether to use relative equality checks + bool check_relative_equality; + + bool verify(ProblemShapeType problem_size, EVTModule& host_reference) { + + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + + auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(), + cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a)); + auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(), + cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b)); + auto LayoutD = cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + /// Reference Kernel + static int constexpr kBlockM = 64; + static int constexpr kBlockN = 64; + +#if defined(_OPENMP) + #pragma omp 
parallel for collapse(3) +#endif + for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { + for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { + ElementAccumulator acc[kBlockM][kBlockN]; + gett_mainloop(mainloop_params, m, n, l, acc); + /// Epilogue EVT + for (int n_b = 0; n_b < kBlockN; ++n_b) { + for (int m_b = 0; m_b < kBlockM; ++m_b) { + if (m + m_b < cute::size<0>(LayoutD) && n + n_b < cute::size<1>(LayoutD)) { + host_reference.visit(m, n, l, m_b, n_b, acc[m_b][n_b]); + } + } + } + } + } + } + + std::stringstream error_ss; + bool passed = host_reference.compare_reference(error_ss); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K + << ", Batch count = " << L << "\n\n"; + + file + << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view() + << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view(); + + file << error_ss.str(); + } + + return passed; + } + + bool run( + ProblemShapeType problem_size, + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, + detail::Splits splits = detail::Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic, + int iterations = 20, + bool profiling = false) { + // Fail test if insufficient CUDA device + if (!impl_.sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." 
<< std::endl; + return false; + } + // + // Initialize the Gemm operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = impl_.sm_count; + } + else { + impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = impl_.sm_count; + } + + typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; + if constexpr (cute::is_same_v) { + scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order, decomposition_mode }; + } + else { + scheduler_args = { static_cast(max_swizzle), raster_order }; + } + + /// Initializes data structures + /// A/B/C/D Tensor + initialize(problem_size); + + /// Initialize the epilogue arguments + EVTModule host_reference(problem_size, check_relative_equality, 2024); + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { + impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a, + impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b + }, + {}, + hw_info, + scheduler_args + }; + + // Filling in the thread arguments + if constexpr (FlatArgs) { + auto epilogue_args = host_reference.get_flatten_arguments(); + std::memcpy(&arguments.epilogue.thread, &epilogue_args, sizeof(epilogue_args)); + + arguments.epilogue.ptr_C = static_cast(host_reference.get_tensor_C_ptr()); + arguments.epilogue.dC = impl_.collective_epilogue.stride_c; + + arguments.epilogue.ptr_D = static_cast(host_reference.get_tensor_D_ptr()); + arguments.epilogue.dD = impl_.collective_epilogue.stride_d; + } + else { + auto epilogue_args = host_reference.get_arguments(); + std::memcpy(&arguments.epilogue, &epilogue_args, sizeof(epilogue_args)); + } + + 
Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + if (profiling) { + return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); + } + else { + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, host_reference); + if (!passed) { + std::cout << "Error : Failed \n"; + } + + return passed; + } +}; + +/// Universal testbed for EVT +template +class Testbed3xEVT { +public: + // The EVT Module to test + using EVTModule = typename EVT::EVTModule; + + using TestBedImpl = typename detail::TestbedImpl; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + using ElementAccumulator = typename Kernel::ElementAccumulator; + using ElementC = typename Kernel::ElementC; + using ElementD = typename Kernel::ElementD; + + using ProblemShapeType = typename Kernel::ProblemShape; + + using LayoutTagA = typename TestBedImpl::LayoutTagA; + using LayoutTagB = typename TestBedImpl::LayoutTagB; + using LayoutTagC = typename TestBedImpl::LayoutTagC; + using LayoutTagD = typename TestBedImpl::LayoutTagD; + + // + // Methods + // + Testbed3xEVT( + bool check_relative_equality_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = 
cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_((check_relative_equality_ ? CheckEquality::RELATIVE : CheckEquality::EXACT), ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(check_relative_equality_) { } + + Testbed3xEVT( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(false) { } + + Testbed3xEVT( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(stride_factor_A_, stride_factor_B_, stride_factor_C_, stride_factor_D_, + CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(false) { } + + /// Initializes data structures + void initialize(ProblemShapeType problem_size) { + // + // Allocate the GEMM workspace for A/B tensor + // + impl_.initialize(problem_size); + } + // Detail Implementation + TestBedImpl impl_; + + // Whether to use relative equality checks + bool 
check_relative_equality; + + bool verify(ProblemShapeType problem_size, EVTModule& host_reference) { + + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + + auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(), + cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a)); + auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(), + cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b)); + auto LayoutD = cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + /// Reference Kernel + static int constexpr kBlockM = 64; + static int constexpr kBlockN = 64; + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { + for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { + ElementAccumulator acc[kBlockM][kBlockN]; + gett_mainloop(mainloop_params, m, n, l, acc); + /// Epilogue EVT + for (int n_b = 0; n_b < kBlockN; ++n_b) { + for (int m_b = 0; m_b < kBlockM; ++m_b) { + if (m + m_b < cute::size<0>(LayoutD) && n + n_b < cute::size<1>(LayoutD)) { + host_reference.visit(m, n, l, m_b, n_b, acc[m_b][n_b]); + } + } + } + } + } + } + + std::stringstream error_ss; + bool passed = host_reference.compare_reference(error_ss); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename 
Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K + << ", Batch count = " << L << "\n\n"; + + file + << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view() + << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view() + << "\nC =\n" << impl_.collective_epilogue.tensor_C.host_view() << "\n\n"; + + file << error_ss.str(); + } + + return passed; + } + + bool run( + ProblemShapeType problem_size, + bool profiling = false, + int iterations = 20, + int splits = 1) { + // Fail test if insufficient CUDA device + if (!impl_.sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } + // + // Initialize the Gemm operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = impl_.sm_count; + } + else { + impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = impl_.sm_count; + } + + typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; + if constexpr (cute::is_same_v) { + scheduler_args = { splits }; + } + + /// Initializes data structures + /// A/B/C/D Tensor + initialize(problem_size); + + /// Initialize the epilogue arguments + EVTModule host_reference(problem_size, check_relative_equality, 2024); + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { + impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a, + impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b + }, + { // Epilogue arguments + {}, // thread + static_cast(host_reference.get_tensor_C_ptr()), + impl_.collective_epilogue.stride_c, + 
static_cast(host_reference.get_tensor_D_ptr()), + impl_.collective_epilogue.stride_d + }, // Epilogue arguments end + hw_info, + scheduler_args + }; + + // Filling in the thread arguments + typename EVTModule::Arguments epilogue_args = host_reference.get_arguments(); + std::memcpy(&arguments.epilogue.thread, &epilogue_args.arg, sizeof(epilogue_args.arg)); + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + if (profiling) { + return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); + } + else { + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, host_reference); + if (!passed) { + std::cout << "Error : Failed \n"; + } + + return passed; + } +}; + +template +bool TestAllEVT(bool check_relative_equality = false) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; + std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + + if constexpr (cute::is_same_v) { + problem_size_m.push_back(768); + problem_size_n.push_back(768); + } + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename 
Gemm::GemmKernel::TileShape{}); + + std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + + Testbed3xEVT testbed(check_relative_equality); + bool passed = true; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + passed = testbed.run(problem_size); + + if (!passed) { + return false; + } + } + } + } + + // if we do support batched GEMM, just run one test on it to save on test time + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; + passed = testbed.run( + problem_size + ); + + if (!passed) { + return false; + } + } + + return passed; +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cbc54ec582d88d9039968d8153cf6127a06ec274 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp @@ -0,0 +1,2409 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Testbed for Ptr-Array and Grouped GEMM interface +*/ + +#pragma once + +#include +#include +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/complex.h" +#include "testbed_utils.h" + +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/gemm.h" + +#include "cute/int_tuple.hpp" +#include "cute/layout.hpp" +#include "cute/numeric/int.hpp" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class ScalarLoc { + ON_HOST = 0, + ON_DEVICE = 1 +}; + +enum class VectorScale { + DISABLED = 0, + ENABLED = 1 +}; + +enum class CheckEquality { + EXACT = 0, + RELATIVE = 1 +}; + +namespace detail{ + +// Helper classes that take default data type when +// the Gemm::EpilogueOutputOp does not have ElementCompute +// and ElementScalar. +// (e.g. 
when Sm90TreeVisitor is used as FusionCallbacks) +template +struct ElementComputeType { + using Type = Default; +}; + +template +struct ElementComputeType> { + using Type = typename Gemm::EpilogueOutputOp::ElementCompute; +}; + +template +struct ElementScalarType { + using Type = Default; +}; + +template +struct ElementScalarType> { + using Type = typename Gemm::EpilogueOutputOp::ElementScalar; +}; + + +template +struct IsF8F6F4Kernel { + static constexpr bool value = false; +}; + +template +struct IsF8F6F4Kernel> { + static constexpr bool value = true; +}; + + +// The maximum swizzle size to use +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. +class MaxSwizzleSize { +public: + MaxSwizzleSize() = default; + + template && + !cute::is_same_v)) > + explicit MaxSwizzleSize(IntegralNotBool max_swizzle_size) : max_swizzle_size_(max_swizzle_size) {} + explicit operator int() const { return max_swizzle_size_; } +private: + int max_swizzle_size_ = 1; +}; + +template +auto make_iterator(T* ptr) { + return cute::recast_ptr(ptr); +} + +template +struct IsDefaultEpilogue { + static constexpr bool value = false; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; + +// The number of splits to test. +// +// This class makes it harder to confuse the order of arguments +// of the various run(...) functions in this file. The constructor +// is explicit, so one can't just type 42 (or false, which the +// compiler unhelpfully turns into 0); one has to type Splits(42). +// Splits() picks the default number of splits, 1. +// +// The conversion-to-int operator (operator int()) MUST be explicit! +// Conversion to int MUST require static_cast. +// Otherwise, that defeats a key purpose of this class, +// which is to catch common errors of confusing the order +// of function arguments. 
+class Splits { +public: + Splits() = default; + + template && + !cute::is_same_v)) > + explicit Splits(IntegralNotBool splits) : splits_(splits) {} + explicit operator int() const { return splits_; } +private: + int splits_ = 1; +}; + +// The number of iterations to test. +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. +// Iterations() picks the default number of iterations, 20. +class Iterations { +public: + Iterations() = default; + + template && + !cute::is_same_v)) > + explicit Iterations(IntegralNotBool iterations) : iterations_(iterations) {} + explicit operator int() const { return iterations_; } +private: + int iterations_ = 20; +}; + +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + + else if (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + + else if (bits_input <= 8) { + + if constexpr ( + cute::is_same_v){ + scope_max = 4; + scope_min = 1; + } + else { + + scope_max = 1; + scope_min = -1; + + } + + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + + else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } + + else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + + else { + 
EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; +} + +// Looks at Cute Stride to check Row / Column Major +template +static constexpr bool is_row_or_col_major(){ + int stride_0 = int(cute::size<0>(Stride{})); + int stride_1 = int(cute::size<1>(Stride{})); + int depth = cute::depth(Stride{}); + return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1); +} + + +// +// Default MMA input Operands : A , B +// +template< + class ScheduleType_, + class Gemm, + class ElementA_ = typename Gemm::GemmKernel::ElementA, + class ElementB_ = typename Gemm::GemmKernel::ElementB> +struct HostCollectiveMainloop { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using InternalStrideA = typename Gemm::GemmKernel::InternalStrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using InternalStrideB = typename Gemm::GemmKernel::InternalStrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + cutlass::ComplexTransform TransformA = Gemm::kTransformA; + cutlass::ComplexTransform TransformB = Gemm::kTransformB; + + std::vector stride_a_host; + std::vector stride_b_host; + + cutlass::DeviceAllocation stride_a_device; + cutlass::DeviceAllocation stride_b_device; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + 
cutlass::Distribution::Kind init_B; + + std::vector> tensors_A; + std::vector> tensors_B; + cutlass::DeviceAllocation device_tensors_A; + cutlass::DeviceAllocation device_tensors_B; + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_), + check_relative_equality(check_relative_equality_) { } + + bool initialize(ProblemShapeType problem_shapes) { + // + // Allocate the GEMM workspace + // + // for pointer array problem_shapes.groups() is 1 + + tensors_A.clear(); + tensors_B.clear(); + stride_a_host.clear(); + stride_b_host.clear(); + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + for(int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_a_host.push_back(cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1})); + stride_b_host.push_back(cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1})); + + // 2.x host tensor does not natively contain a batch stride or coord, so 
we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N); + + tensors_A.push_back(cutlass::HostTensor(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A))); + tensors_B.push_back(cutlass::HostTensor(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B))); + + EXPECT_TRUE(initialize_tensor(tensors_A[i].host_view(), init_A, seed + 2022 + i)); + EXPECT_TRUE(initialize_tensor(tensors_B[i].host_view(), init_B, seed + 2021 + i)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensors_A[i].host_view().at({0, 0}) = ElementA(1); + tensors_B[i].host_view().at({0, 0}) = ElementB(1); + + tensors_A[i].sync_device(); + tensors_B[i].sync_device(); + } + + return true; + } + + Arguments to_args(ProblemShapeType problem_shapes) { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + std::vector ptr_A_host(L); + std::vector ptr_B_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_A_host.at(i) = tensors_A[i].device_data(); + ptr_B_host.at(i) = tensors_B[i].device_data(); + } + + device_tensors_A.reset(L); + device_tensors_A.copy_from_host(ptr_A_host.data()); + + device_tensors_B.reset(L); + device_tensors_B.copy_from_host(ptr_B_host.data()); + + stride_a_device.reset(problem_shapes.groups()); + stride_a_device.copy_from_host(stride_a_host.data()); + stride_b_device.reset(problem_shapes.groups()); + stride_b_device.copy_from_host(stride_b_host.data()); + + Arguments arguments; + + if constexpr (IsGroupGemm) { + arguments + = + { + device_tensors_A.get(), stride_a_device.get(), device_tensors_B.get(), 
stride_b_device.get() + }; + } + else { + arguments = + { + device_tensors_A.get(), stride_a_host[0], device_tensors_B.get(), stride_b_host[0] + }; + } + + return arguments; + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto A = make_tensor(make_iterator(tensors_A[batch].host_data()), + make_layout(make_shape(M, K, 1), stride_a_host[batch])); + auto B = make_tensor(make_iterator(tensors_B[batch].host_data()), + make_layout(make_shape(N, K, 1), stride_b_host[batch])); + + cutlass::reference::host::GettMainloopParams mainloop_params{}; + + mainloop_params.A = A; + mainloop_params.B = B; + mainloop_params.transform_A = TransformA; + mainloop_params.transform_B = TransformB; + + return mainloop_params; + } + + void print_tensors(std::ofstream& file, int batch) { + file << "A =\n" << tensors_A[batch].host_view() + << "\nB =\n" << tensors_B[batch].host_view(); + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. 
+ Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + ProblemShapeType problem_shapes, int batch) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_A[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_B[batch].host_view()), 0); + + bool passed = true; + return passed; + } +}; + + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using InternalStrideA = typename Gemm::GemmKernel::InternalStrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using InternalStrideB = typename Gemm::GemmKernel::InternalStrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + + using 
ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; + using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using InternalLayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + using InternalLayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + std::vector stride_a_host; + std::vector stride_b_host; + cutlass::DeviceAllocation stride_a_device; + cutlass::DeviceAllocation stride_b_device; + + std::vector layout_sfa_host; + std::vector layout_sfb_host; + cutlass::DeviceAllocation layout_sfa_device; + cutlass::DeviceAllocation layout_sfb_device; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + std::vector> tensors_A; + std::vector> tensors_B; + std::vector> tensors_SFA; + std::vector> tensors_SFB; + + cutlass::DeviceAllocation device_tensors_A; + cutlass::DeviceAllocation device_tensors_B; + cutlass::DeviceAllocation device_tensors_SFA; + cutlass::DeviceAllocation device_tensors_SFB; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + 
CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + check_relative_equality(check_relative_equality_), + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_) { } + + template + bool initialize(ProblemShapeType problem_shapes) { + // + // Allocate the GEMM workspace + // + + tensors_A.clear(); + tensors_B.clear(); + stride_a_host.clear(); + stride_b_host.clear(); + tensors_SFA.clear(); + tensors_SFB.clear(); + layout_sfa_host.clear(); + layout_sfb_host.clear(); + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_a_host.push_back(cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1})); + stride_b_host.push_back(cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1})); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N); + + tensors_A.push_back(cutlass::HostTensor(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A))); + tensors_B.push_back(cutlass::HostTensor(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B))); + + 
EXPECT_TRUE(initialize_tensor(tensors_A[i].host_view(), init_A, seed + 2022 + i)); + EXPECT_TRUE(initialize_tensor(tensors_B[i].host_view(), init_B, seed + 2021 + i)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensors_A[i].host_view().at({0, 0}) = ElementA(1); + tensors_B[i].host_view().at({0, 0}) = ElementB(1); + + tensors_A[i].sync_device(); + tensors_B[i].sync_device(); + + using namespace cute; + + auto k_blks = cutlass::ceil_div(K, size<1>(shape(SfAtom{}))); + auto m_blks = cutlass::ceil_div(M, Blk_MN{}); + auto n_blks = cutlass::ceil_div(N, Blk_MN{}); + layout_sfa_host.push_back(Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1))); + layout_sfb_host.push_back(Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1))); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{}, k_blks * Blk_SF{}); + auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{}, k_blks * Blk_SF{}); + + tensors_SFA.push_back(cutlass::HostTensor(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A))); + tensors_SFB.push_back(cutlass::HostTensor(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B))); + + EXPECT_TRUE(initialize_tensor(tensors_SFA[i].host_view(), init_A, seed + 2024 + i)); + EXPECT_TRUE(initialize_tensor(tensors_SFB[i].host_view(), init_B, seed + 2025 + i)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensors_SFA[i].host_view().at({0, 0}) = ElementSF(1); + tensors_SFB[i].host_view().at({0, 0}) = ElementSF(1); + + tensors_SFA[i].sync_device(); + tensors_SFB[i].sync_device(); + } + + return true; + } + + Arguments to_args(ProblemShapeType problem_shapes) { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + std::vector ptr_A_host(L); + std::vector ptr_B_host(L); + std::vector ptr_SFA_host(L); + std::vector ptr_SFB_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_A_host.at(i) = tensors_A[i].device_data(); + ptr_B_host.at(i) = tensors_B[i].device_data(); + ptr_SFA_host.at(i) = tensors_SFA[i].device_data(); + ptr_SFB_host.at(i) = tensors_SFB[i].device_data(); + } + + device_tensors_A.reset(L); + device_tensors_A.copy_from_host(ptr_A_host.data()); + + device_tensors_B.reset(L); + device_tensors_B.copy_from_host(ptr_B_host.data()); + + device_tensors_SFA.reset(L); + device_tensors_SFA.copy_from_host(ptr_SFA_host.data()); + + device_tensors_SFB.reset(L); + device_tensors_SFB.copy_from_host(ptr_SFB_host.data()); + + stride_a_device.reset(problem_shapes.groups()); + stride_a_device.copy_from_host(stride_a_host.data()); + + stride_b_device.reset(problem_shapes.groups()); + stride_b_device.copy_from_host(stride_b_host.data()); + + layout_sfa_device.reset(problem_shapes.groups()); + layout_sfa_device.copy_from_host(layout_sfa_host.data()); + + layout_sfb_device.reset(problem_shapes.groups()); + layout_sfb_device.copy_from_host(layout_sfb_host.data()); + + if constexpr (IsGroupGemm) { + return Arguments{ + device_tensors_A.get(), stride_a_device.get(), + device_tensors_B.get(), stride_b_device.get(), + device_tensors_SFA.get(), layout_sfa_device.get(), + device_tensors_SFB.get(), layout_sfb_device.get() + }; + } + else { + return Arguments{ + device_tensors_A.get(), stride_a_host[0], + device_tensors_B.get(), stride_b_host[0], + device_tensors_SFA.get(), layout_sfa_host[0], + 
device_tensors_SFB.get(), layout_sfb_host[0] + }; + } + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto A = make_tensor(make_iterator(tensors_A[batch].host_data()), + make_layout(make_shape(M, K, 1), stride_a_host[batch])); + auto SfA = make_tensor(tensors_SFA[batch].host_data(), layout_sfa_host[batch]); + + auto B = make_tensor(make_iterator(tensors_B[batch].host_data()), + make_layout(make_shape(N, K, 1), stride_b_host[batch])); + auto SfB = make_tensor(tensors_SFB[batch].host_data(), layout_sfb_host[batch]); + + return cutlass::reference::host::GettMainloopParams + {A, SfA, B, SfB}; + } + + void print_tensors(std::ofstream& file, int batch) { + file << "A =\n" << tensors_A[batch].host_view() + << "\nB =\n" << tensors_B[batch].host_view() + << "\nSFA =\n" << tensors_SFA[batch].host_view() + << "\nSFB =\n" << tensors_SFB[batch].host_view(); + } + + bool compare_reference( + ProblemShapeType problem_shapes, int batch) { + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_A[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_B[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_SFA[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_SFB[batch].host_view()), 0); + return true; + } +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = 
cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = 
cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +template +struct HostCollectiveDefaultEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using InternalStrideD = typename kernel::InternalStrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; + using InternalStrideC = typename kernel::InternalStrideC; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + using FusionOp = typename Gemm::EpilogueOutputOp; + + static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + 
using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + cutlass::DeviceAllocation stride_c_device; + cutlass::DeviceAllocation stride_d_device; + + std::vector stride_c_host; + std::vector stride_d_host; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + + // Inputs + ElementScalar alpha; + ElementScalar beta; + + std::vector> tensors_C; + std::vector> tensors_D; + std::vector> references_D; + cutlass::DeviceAllocation device_tensors_C; + cutlass::DeviceAllocation device_tensors_D; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If per-row scale is enabled and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector + VectorScale vector_scale_mode = VectorScale::DISABLED; + + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + HostCollectiveDefaultEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + 
check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } + + bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + // Initialize Epilogue tensors + + tensors_C.clear(); + tensors_D.clear(); + references_D.clear(); + stride_c_host.clear(); + stride_d_host.clear(); + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_c_host.push_back(cutlass::make_cute_packed_stride(InternalStrideC{}, {M, N, 1})); + stride_d_host.push_back(cutlass::make_cute_packed_stride(InternalStrideD{}, {M, N, 1})); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto c_coord = cutlass::make_Coord(M, N); + + tensors_C.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C))); + tensors_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D))); + references_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false)); + EXPECT_TRUE(initialize_tensor(tensors_C[i].host_view(), init_C, seed + 2020)); + tensors_C[i].host_view().at({0, 0}) = ElementC(1); + + cutlass::reference::host::TensorCopy(references_D[i].host_view(), tensors_C[i].host_view()); + tensors_C[i].sync_device(); + tensors_D[i].sync_device(); + } + alpha = alpha_; + beta = beta_; + + return true; + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. 
CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta, + int batch) { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + tensors_D[batch].sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_C[batch].host_view()), 0); + + if (tensors_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_D[batch].host_view()), 0); + } + + if (references_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(references_D[batch].host_view()), 0); + } + + bool passed = equality_check(references_D[batch].host_view(), tensors_D[batch].host_view()); + if(!passed) { + std::cout<<"D is incorrect"<(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + std::vector ptr_C_host(L); + std::vector ptr_D_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_C_host.at(i) = tensors_C[i].device_data(); + ptr_D_host.at(i) = tensors_D[i].device_data(); + } + + device_tensors_C.reset(L); + device_tensors_C.copy_from_host(ptr_C_host.data()); + + device_tensors_D.reset(L); + device_tensors_D.copy_from_host(ptr_D_host.data()); + + 
stride_c_device.reset(problem_shapes.groups()); + stride_c_device.copy_from_host(stride_c_host.data()); + + stride_d_device.reset(problem_shapes.groups()); + stride_d_device.copy_from_host(stride_d_host.data()); + + Arguments arguments; + if constexpr (IsGroupGemm) { + arguments = + { + {alpha, beta}, + device_tensors_C.get(), stride_c_device.get(), device_tensors_D.get(), stride_d_device.get() + }; + } + else { + arguments = + { + {alpha, beta}, + device_tensors_C.get(), stride_c_host[0], device_tensors_D.get(), stride_d_host[0] + }; + } + + return arguments; + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + L = std::max(problem_shapes.groups(), L); + + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_c_host[batch])); + auto D = cute::make_tensor(detail::make_iterator(references_D[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_d_host[batch])); + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D)> + epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha; + epilogue_params.beta = beta; + + return epilogue_params; + } +}; + +template +struct HostCollectiveEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + static_assert(IsDefaultEpilogue::value 
== false, "Default Epilogue is not supported"); + + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using InternalStrideD = typename kernel::InternalStrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; + using InternalStrideC = typename kernel::InternalStrideC; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + + // + // FusionOperation derived types/queries + // + using EpiloguePolicy = typename Epilogue::DispatchPolicy; + static constexpr bool IsLegacy = + cute::is_same_v< + EpiloguePolicy, + cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise< + EpiloguePolicy::StagesC, EpiloguePolicy::StagesD, EpiloguePolicy::FragmentSize> + >; + + using FusionOp = typename Gemm::EpilogueOutputOp; + static_assert(cute::is_base_of_v); + + + // Scale factor Generation related + using SfStrategy = cutlass::reference::host::SfStrategy; + static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; + static constexpr SfStrategy SfGenStrategy = (!IsBlockScaleSupported) ? 
SfStrategy::None : SfStrategy::SfDGen; + static constexpr int32_t SFD_VectorSize = IsBlockScaleSupported ? FusionOp::SFVecSize : 1; + using ElementSFD = non_void_t, ElementD>; + using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig< + SFD_VectorSize + >; + using Blk_MN = typename Sm1xxBlockScaledOutputConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlockScaledOutputConfig::Blk_SF; + using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom; + std::vector> tensors_SFD; + std::vector> references_SFD; + cutlass::DeviceAllocation device_tensors_SFD; + + using ElementCompute = typename FusionOp::ElementCompute; + using ElementScalar = typename FusionOp::ElementScalar; + using ElementBias = non_void_t; + using ElementAux = non_void_t; + using ElementAmax = non_void_t; + using LayoutTagAux = non_void_t; + using ActivationFunctor = non_void_t>; + + static constexpr bool IsBiasEnabled = FusionOp::IsPerRowBiasSupported; + static constexpr bool IsDeBiasEnabled = FusionOp::IsDePerRowBiasSupported; + static constexpr bool IsPerRowScaleEnabled = FusionOp::IsPerRowScaleSupported; + static constexpr bool IsScaleFactorEnabled = FusionOp::IsScaleFactorSupported; + static constexpr bool IsAuxInEnabled = FusionOp::IsAuxInSupported; + static constexpr bool IsAuxOutEnabled = FusionOp::IsAuxOutSupported; + static constexpr bool IsAbsMaxEnabledD = FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + cutlass::DeviceAllocation stride_c_device; + cutlass::DeviceAllocation stride_d_device; + + std::vector stride_c_host; + std::vector stride_d_host; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + + // Inputs + cutlass::HostTensor alpha; + cutlass::HostTensor beta; 
+ cutlass::HostTensor scale_A; + cutlass::HostTensor scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor bias; + std::vector> tensors_C; + cutlass::DeviceAllocation device_tensors_C; + cutlass::HostTensor norm_constant; + + // Outputs + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + std::vector> tensors_Aux; + cutlass::DeviceAllocation device_tensors_Aux; + cutlass::gemm::TagToStrideC_t< LayoutTagAux > stride_Aux; + std::vector> tensors_D; + std::vector> references_D; + cutlass::DeviceAllocation device_tensors_D; + + // References + cutlass::HostTensor reference_dbias; + std::vector> references_Aux; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If per-row scale is enabled and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector + VectorScale vector_scale_mode = VectorScale::DISABLED; + + // Random distribution with which to initialize the A/B/C/D/Aux scaling factors + cutlass::Distribution::Kind init_scale = cutlass::Distribution::Uniform; + // Random distribution with which to initialize the bias vector + cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + HostCollectiveEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ 
= cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_scale(init_scale_), init_bias(init_bias_), + init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } + + bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + // Initialize Epilogue tensors + + tensors_C.clear(); + tensors_D.clear(); + references_D.clear(); + stride_c_host.clear(); + stride_d_host.clear(); + + tensors_SFD.clear(); + references_SFD.clear(); + + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_c_host.push_back(cutlass::make_cute_packed_stride(InternalStrideC{}, {M, N, 1})); + stride_d_host.push_back(cutlass::make_cute_packed_stride(InternalStrideD{}, {M, N, 1})); + + auto c_coord = cutlass::make_Coord(M, N); + tensors_C.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C))); + tensors_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D))); + references_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false)); + EXPECT_TRUE(initialize_tensor(tensors_C[i].host_view(), init_C, seed + 2020)); + tensors_C[i].host_view().at({0, 0}) = ElementC(1); + + cutlass::reference::host::TensorCopy(references_D[i].host_view(), tensors_C[i].host_view()); + tensors_C[i].sync_device(); + tensors_D[i].sync_device(); + } + + auto scalar_coord = cutlass::make_Coord(1); + auto col_vector_coord = cutlass::make_Coord(M); + if constexpr (IsPerRowScaleEnabled) { + 
alpha.resize(col_vector_coord); + EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); + if (vector_scale_mode == VectorScale::DISABLED) { + beta.resize(scalar_coord, false); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + else { + beta.resize(col_vector_coord); + EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); + } + } + else { + alpha.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + beta.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + alpha.sync_device(); + beta.sync_device(); + + if constexpr (IsScaleFactorEnabled) { + scale_A.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_B.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_C.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_D.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_A.host_view(), init_scale, seed + 2023)); + EXPECT_TRUE(initialize_tensor(scale_B.host_view(), init_scale, seed + 2024)); + EXPECT_TRUE(initialize_tensor(scale_C.host_view(), init_scale, seed + 2025)); + EXPECT_TRUE(initialize_tensor(scale_D.host_view(), init_scale, seed + 2026)); + scale_A.sync_device(); + scale_B.sync_device(); + scale_C.sync_device(); + scale_D.sync_device(); + } + + if constexpr (IsBiasEnabled) { + bias.resize(col_vector_coord); + EXPECT_TRUE(initialize_tensor(bias.host_view(), init_bias, seed + 2023)); + bias.sync_device(); + } + + if constexpr (IsDeBiasEnabled) { + bias.resize(col_vector_coord); + reference_dbias.resize(col_vector_coord); + cutlass::reference::host::TensorFill(bias.host_view(), ElementBias(0)); + cutlass::reference::host::TensorFill(reference_dbias.host_view(), ElementBias(0)); + bias.sync_device(); + } + + if constexpr 
(IsAbsMaxEnabledD) { + abs_max_D.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_D.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); + abs_max_D.sync_device(); + reference_abs_max_D.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_D.host_view(), ElementAmax(0)); + } + + tensors_Aux.clear(); + references_Aux.clear(); + + static_assert(!IsGroupGemm or (IsGroupGemm and !IsAuxInEnabled)); + + if constexpr (IsAuxInEnabled) { + auto aux_coord = cutlass::make_Coord(M, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + for (int32_t i = 0; i < L; ++i) { + tensors_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout)); + EXPECT_TRUE(initialize_tensor(tensors_Aux[i].host_view(), init_C, seed + 2023)); + tensors_Aux[i].sync_device(); + } + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, 1)); + } + + static_assert(!IsGroupGemm or (IsGroupGemm and !IsAuxOutEnabled)); + + if constexpr (IsAuxOutEnabled) { + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto aux_coord = cutlass::make_Coord(M, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + tensors_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout)); + references_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout, false)); + tensors_Aux[i].sync_device(); + } + + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, 1)); + + if constexpr (IsScaleFactorEnabled) { + scale_Aux.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_Aux.host_view(), init_scale, seed + 2027)); + scale_Aux.sync_device(); + } + + if 
constexpr (IsAbsMaxEnabledAux) { + abs_max_Aux.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_Aux.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); + abs_max_Aux.sync_device(); + reference_abs_max_Aux.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_Aux.host_view(), ElementAmax(0)); + } + } + + + if constexpr (IsBlockScaleSupported) { + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, _] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + // If block scaled output is supported we always have at least 1 SFD + auto m_blks = cutlass::ceil_div(M, cute::size<0>(cute::shape(OutputSFAtom{}))); + auto n_blks = cutlass::ceil_div(N, cute::size<1>(cute::shape(OutputSFAtom{}))); + auto sfd_coord = [&] () { + return cutlass::make_Coord(m_blks * Blk_MN{}, n_blks * Blk_SF{}); + }(); + tensors_SFD.push_back(cutlass::HostTensor(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D))); + references_SFD.push_back(cutlass::HostTensor(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D), false)); + tensors_SFD[i].sync_device(); + } + norm_constant.resize(scalar_coord, true); + EXPECT_TRUE(initialize_tensor(norm_constant.host_view(), init_scale, seed + 2023)); + norm_constant.sync_device(); + } + + + return true; + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. 
+ Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta, + int batch) { + tensors_D[batch].sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_C[batch].host_view()), 0); + + if (tensors_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_D[batch].host_view()), 0); + } + + if (references_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(references_D[batch].host_view()), 0); + } + + bool passed = equality_check(references_D[batch].host_view(), tensors_D[batch].host_view()); + if(!passed) { + std::cout<<"D is incorrect"<(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + std::vector ptr_C_host(L); + std::vector ptr_D_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_C_host.at(i) = tensors_C[i].device_data(); + ptr_D_host.at(i) = tensors_D[i].device_data(); + } + + device_tensors_C.reset(L); + device_tensors_C.copy_from_host(ptr_C_host.data()); + + device_tensors_D.reset(L); + device_tensors_D.copy_from_host(ptr_D_host.data()); + + stride_c_device.reset(problem_shapes.groups()); + stride_c_device.copy_from_host(stride_c_host.data()); + + stride_d_device.reset(problem_shapes.groups()); + stride_d_device.copy_from_host(stride_d_host.data()); + + std::vector ptr_Aux_host(L); + if constexpr (IsAuxInEnabled || IsAuxOutEnabled) { + for (int32_t i = 0; i < L; ++i) { + ptr_Aux_host.at(i) = tensors_Aux[i].device_data(); + } + device_tensors_Aux.reset(L); + 
device_tensors_Aux.copy_from_host(ptr_Aux_host.data()); + } + + auto device_tensors_C_ptr = cute::is_void_v ? nullptr : + reinterpret_cast(device_tensors_C.get()); + + Arguments arguments; + if constexpr (IsGroupGemm) { + arguments = + { + {}, + device_tensors_C_ptr, stride_c_device.get(), device_tensors_D.get(), stride_d_device.get() + }; + } + else { + arguments = + { + {}, + device_tensors_C_ptr, stride_c_host[0], device_tensors_D.get(), stride_d_host[0] + }; + } + + auto &fusion_args = arguments.thread; + if constexpr (IsLegacy) { + arguments.thread = { + alpha.at(coord_0), + beta.at(coord_0), + alpha.device_data(), + beta.device_data() + }; + arguments.ptr_Bias = bias.device_data(); + arguments.ptr_T = device_tensors_Aux.get(); + } + else { + fusion_args.alpha = alpha.at(coord_0); + fusion_args.beta = beta.at(coord_0); + + fusion_args.alpha_ptr = alpha.device_data(); + // can_implement requires beta_ptr to not be set if its voidC + fusion_args.beta_ptr = cute::is_void_v ? nullptr : + beta.device_data(); + + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_a = scale_A.at(coord_0); + fusion_args.scale_b = scale_B.at(coord_0); + fusion_args.scale_c = scale_C.at(coord_0); + fusion_args.scale_d = scale_D.at(coord_0); + fusion_args.scale_a_ptr = scale_A.device_data(); + fusion_args.scale_b_ptr = scale_B.device_data(); + fusion_args.scale_c_ptr = scale_C.device_data(); + fusion_args.scale_d_ptr = scale_D.device_data(); + } + + if constexpr (IsBiasEnabled) { + fusion_args.bias_ptr = bias.device_data(); + } + + if constexpr (IsDeBiasEnabled) { + fusion_args.dbias_ptr = bias.device_data(); + } + + // example of how to set kernel activation arguments + // see ActivationFunctor::Arguments in activation.h for definition + // if Arguments doesn't exist then fusion_args.activation is empty + if constexpr (cute::is_same_v>) { + fusion_args.activation.scale = ElementCompute(1); + } + + // Treat Clamp as ReLU + if constexpr (cute::is_same_v>) { + 
fusion_args.activation.lower_bound = 0; + fusion_args.activation.upper_bound = std::numeric_limits::max(); + } + + if constexpr (IsAbsMaxEnabledD) { + fusion_args.amax_D_ptr = abs_max_D.device_data(); + } + + if constexpr (IsAuxInEnabled) { + fusion_args.aux_ptr = device_tensors_Aux.get(); + fusion_args.dAux = stride_Aux; + } + + if constexpr (IsAuxOutEnabled) { + fusion_args.aux_ptr = device_tensors_Aux.get(); + fusion_args.dAux = stride_Aux; + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_aux = scale_Aux.at(coord_0); + fusion_args.scale_aux_ptr = scale_Aux.device_data(); + } + if constexpr (IsAbsMaxEnabledAux) { + fusion_args.amax_aux_ptr = abs_max_Aux.device_data(); + } + } + + if constexpr (IsBlockScaleSupported) { + std::vector ptr_SFD_host(L); + for (int32_t i = 0; i < L; ++i) { + ptr_SFD_host.at(i) = tensors_SFD[i].device_data(); + } + device_tensors_SFD.reset(L); + device_tensors_SFD.copy_from_host(ptr_SFD_host.data()); + + arguments.thread.block_scale_factor_ptr = device_tensors_SFD.get(); + arguments.thread.norm_constant_ptr = norm_constant.device_data(); + } + + } + + return arguments; + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto [M, N, K, L] = problem_shape_MNKL; + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_c_host[batch])); + auto D = cute::make_tensor(detail::make_iterator(references_D[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_d_host[batch])); + auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? 
reference_dbias.host_data() : bias.host_data()), + cute::make_layout(cute::make_shape(M, cute::_1{}))); + auto Aux_layout = cute::make_layout(cute::make_shape(M, N, 1), stride_Aux); + auto Aux = [&]() { + auto ptr = recast_ptr(nullptr); + if (IsAuxInEnabled) { + ptr = detail::make_iterator(tensors_Aux[batch].host_data()); + } else if (IsAuxOutEnabled) { + ptr = detail::make_iterator(references_Aux[batch].host_data()); + } + return cute::make_tensor(ptr, Aux_layout); + }(); + auto Valpha = cute::make_tensor(detail::make_iterator(alpha.host_data()), + cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, M))); + auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()), + cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, N))); + + auto SfD = [&](){ + if constexpr (IsBlockScaleSupported) { + auto tensor = make_tensor(detail::make_iterator(references_SFD[batch].host_data()), + Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL)); + return tensor; + } + else { + // Reference kernel has a logic to ignore scalefactor computation if we pass the tensor type same as output D tensor. 
+ return D; + } + }(); + + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + decltype(Bias), + decltype(Aux), + decltype(Valpha), + decltype(Vbeta), + ActivationFunctor + , decltype(SfD) + , Int + , cutlass::plus + , false + , SfGenStrategy + > epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha.at(coord_0); + epilogue_params.beta = beta.at(coord_0); + + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_a = scale_A.at(coord_0); + epilogue_params.scale_b = scale_B.at(coord_0); + epilogue_params.scale_c = scale_C.at(coord_0); + epilogue_params.scale_d = scale_D.at(coord_0); + } + + if constexpr (IsBiasEnabled or IsDeBiasEnabled) { + epilogue_params.Bias = Bias; + } + + if constexpr (IsAbsMaxEnabledD) { + epilogue_params.abs_max_D = reference_abs_max_D.host_data(); + } + + if constexpr (IsAuxInEnabled) { + epilogue_params.Aux = Aux; + } + + if constexpr (IsAuxOutEnabled) { + epilogue_params.Aux = Aux; + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_aux = scale_Aux.at(coord_0); + } + if constexpr (IsAbsMaxEnabledAux) { + epilogue_params.abs_max_Aux = reference_abs_max_Aux.host_data(); + } + } + + if constexpr (IsPerRowScaleEnabled) { + epilogue_params.Valpha = Valpha; + if (vector_scale_mode == VectorScale::ENABLED) { + epilogue_params.Vbeta = Vbeta; + } + } + + if constexpr (IsBlockScaleSupported) { + epilogue_params.SfD = SfD; + epilogue_params.st = norm_constant.at(coord_0); + } + + return epilogue_params; + } +}; + +template < + typename Gemm, + template class ActivationFunctor_ = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false, + typename ElementA = typename Gemm::GemmKernel::ElementA, + typename ElementB = typename Gemm::GemmKernel::ElementB +> +struct TestbedImpl { + // Kernel data types + using ScheduleType = typename 
Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + // All Collective MMA operands are defined by HostCollectiveMainloopType based on the schedule type + using HostCollectiveMainloopType = HostCollectiveMainloop; + using CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, + HostCollectiveDefaultEpilogue, + HostCollectiveEpilogue>; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using LayoutTagA = typename HostCollectiveMainloopType::LayoutTagA; + using LayoutTagB = typename HostCollectiveMainloopType::LayoutTagB; + using LayoutTagC = typename CollectiveEpilogue::LayoutTagC; + using LayoutTagD = typename CollectiveEpilogue::LayoutTagD; + + uint32_t sm_count; + // Used to force multi-wave tests for persistent kernel schedules + constexpr static int MaxSmCount = 16; + static constexpr uint64_t kDefaultSeed = 4096; + static constexpr uint32_t mma_promotion_interval = 4; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + HostCollectiveMainloopType collective_mma_inputs; + CollectiveEpilogue collective_epilogue; + + static constexpr bool IsGroupGemm = CollectiveEpilogue::IsGroupGemm; + + // + // Methods + // + + TestbedImpl( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + 
cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } + + TestbedImpl( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, stride_factor_A_, stride_factor_B_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } + + /// Initializes data structures + bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + collective_mma_inputs.initialize(problem_shapes); + collective_epilogue.initialize(problem_shapes, alpha_, beta_); + + return true; + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar 
beta, + int batch) + { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + + bool passed = collective_mma_inputs.compare_reference(problem_shapes, batch); + passed &= collective_epilogue.compare_reference(problem_shapes, alpha, beta, batch); + EXPECT_TRUE(passed); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << batch << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << batch + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + collective_mma_inputs.print_tensors(file, batch); + collective_epilogue.print_tensors(file, batch); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta) + { + using namespace cute; + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + bool passed = true; + for (int32_t i = 0; i < L; ++i) { + auto mainloop_params = collective_mma_inputs.to_host_args(problem_shapes, i); + auto epilogue_params = collective_epilogue.to_host_args(problem_shapes, i); + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + passed &= compare_reference(problem_shapes, alpha, beta, i); + } + return passed; + } + + /// Determine if the CUDA device is sufficient to run the kernel + bool sufficient() { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = static_cast(Gemm::GemmKernel::SharedStorageSize); + + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw 
std::runtime_error("cudaGetDevice() API call failed."); + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + this->sm_count = properties.multiProcessorCount; + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } + + return true; + } + + /// Executes one test + bool run( + ProblemShapeType problem_shapes, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + detail::Iterations iterations = detail::Iterations{} + ) + { + + // Fail test if insufficient CUDA device + if (!sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } + + if (!this->initialize(problem_shapes, alpha, beta)) { + std::cerr << "Initialization failed \n"; + return false; + } + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + this->sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = this->sm_count; + + typename HostCollectiveMainloopType::Arguments mainloop_args; + + mainloop_args = collective_mma_inputs.to_args(problem_shapes); + + if constexpr (IsGroupGemm) { + arguments = + { + cutlass::gemm::GemmUniversalMode::kGrouped, + problem_shapes, + mainloop_args, + collective_epilogue.to_args(problem_shapes), + hw_info + }; + } + else { + arguments = + { + cutlass::gemm::GemmUniversalMode::kArray, + problem_shapes, + mainloop_args, + collective_epilogue.to_args(problem_shapes), + hw_info + }; + } + + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation 
workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return false; + } + + // + // Run the GEMM + // + + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_shapes, alpha, beta); + if (!passed) { + std::cout << "Error : Failed : with alpha: " << alpha << ", beta: " << beta + << "\n"; + } + + return passed; + } +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false, + typename ElementA = typename Gemm::GemmKernel::ElementA, + typename ElementB = typename Gemm::GemmKernel::ElementB +> +struct Testbed3x { + + using TestBedImpl = typename detail::TestbedImpl< + Gemm, + ActivationFunctor, + force_legacy_epilogue, + ElementA, + ElementB + >; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + + using ElementAccumulator = typename TestBedImpl::ElementAccumulator; + using ElementCompute = typename TestBedImpl::ElementCompute; + using ElementScalar = typename TestBedImpl::ElementScalar; + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename 
cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + static constexpr bool IsGroupGemm = TestBedImpl::IsGroupGemm; + + // Detail Implementation + TestBedImpl impl_; + + // + // Methods + // + Testbed3x( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed) + : impl_(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_A_, init_B_, init_C_, init_scale_, init_bias_, seed_) {} + + /// Executes one test + bool run( + typename TestBedImpl::ProblemShapeType problem_shapes, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + detail::Iterations iterations = detail::Iterations{} + ) + { + return impl_.run( + problem_shapes, alpha, beta, iterations); + } +}; + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity +> +bool TestAll(double alpha = 1.0, double beta = 0.0, CheckEquality check_relative_equality = CheckEquality::RELATIVE) { + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + Testbed3x testbed(check_relative_equality, ScalarLoc::ON_DEVICE, VectorScale::DISABLED); + + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; + std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + + constexpr int Stages = 
Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + + int batches[] = {5, 10}; + + bool passed = true; + + for (int batch : batches) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + + if constexpr (Testbed3x::IsGroupGemm) { + std::vector problem_sizes_host; + cutlass::DeviceAllocation problem_sizes_device; + + for (int i = 0; i < batch; ++i) { + problem_sizes_host.push_back({m * ((i % 3) + 1), n * ((i % 4) + 1), k * ((i % 5) + 1)}); + } + + problem_sizes_device.reset(problem_sizes_host.size()); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + passed = testbed.run( + ProblemShapeType{static_cast(problem_sizes_host.size()), problem_sizes_device.get(), problem_sizes_host.data()}, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + else { + ProblemShapeType problem_size{{m, n, k, batch}}; + + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNKL " << m << " " << n << " " << k << " " << batch << " FAILED.\n"; + return false; + } + } // k + } // n + } // m + } // batch + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestSmall(double alpha = 1.0, double beta = 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED, + std::vector override_problem_size_k = {}) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using ElementA = typename Gemm::GemmKernel::ElementA; + using ElementB = 
typename Gemm::GemmKernel::ElementB; + using TiledMma = typename Gemm::GemmKernel::TiledMma; + + static constexpr bool IsF8F6F4 = cutlass::gemm::collective::detail::is_sm100_mma_f8f6f4(); + // For fp4 and fp6 kernels, the min alignment_input is 128 elements, so we don't need to add alignment_input in test problem sizes. + int alignment_bits_a = cutlass::detail::get_input_alignment_bits(); + int alignment_input_a = (alignment_bits_a / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits_a / cute::sizeof_bits::value); + + int alignment_bits_b = cutlass::detail::get_input_alignment_bits(); + int alignment_input_b = (alignment_bits_b / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits_b / cute::sizeof_bits::value); + + int alignment_input = (alignment_input_a == 0 || alignment_input_b == 0) ? 0 : std::max(alignment_input_a, alignment_input_b); + + if constexpr (apply_alignment_offset) { + // If BlockScaled, then min alignment is SFVecSize + static constexpr bool IsBlockScaleSupported = Gemm::EpilogueOutputOp::IsBlockScaleSupported; + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + if constexpr (IsBlockScaleSupported) { + alignment_input = cutlass::round_up(alignment_input, SFVecSize); + } + } + + + using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; + using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; + CtaShape_MNK cta_shape; + Testbed3x testbed(check_relative_equality, use_device_scalars, vector_scale_mode); + // For Ptr-Array and Grouped GEMM ideally we need to know SM count at runtime + static constexpr int SmCount = 16; + + float waves[] = {0.5, 2.5}; + int batches[] = {3}; + int cluster_m = 1; + int cluster_n = 1; + + std::vector problem_size_k; + if (override_problem_size_k.empty()) { + // this is to test with min alignment + problem_size_k = {256 - alignment_input, 512 + alignment_input}; + } + else { + problem_size_k = override_problem_size_k; + } + + if 
constexpr(DispatchPolicy::ArchTag::kMinComputeCapability >= 90) { + typename DispatchPolicy::ClusterShape cluster_shape; + cluster_m = cute::size<0>(cluster_shape); + cluster_n = cute::size<1>(cluster_shape); + } + + bool passed = true; + + for (int batch : batches) { + for (float wave : waves) { + for (int k : problem_size_k) { + int grid_m, grid_n = 0; + float num_grid = wave * SmCount; + + if (cluster_m >= cluster_n) { + grid_m = cluster_m; + grid_n = static_cast(num_grid) / grid_m; + // Align grid_n to cluster_n + grid_n = std::max((grid_n + cluster_n - 1 ) / cluster_n * cluster_n, 1); + } + else { + grid_n = cluster_n; + grid_m = static_cast(num_grid) / grid_n; + // Align grid_m to cluster_m + grid_m = std::max((grid_m + cluster_m - 1 ) / cluster_m * cluster_m, 1); + } + + int m = grid_m * cute::size<0>(cta_shape) - alignment_input; // this is just to test with unusual problem shapes + int n = grid_n * cute::size<1>(cta_shape) + alignment_input; + + if constexpr (Testbed3x::IsGroupGemm) { + std::vector problem_sizes_host; + cutlass::DeviceAllocation problem_sizes_device; + for (int i = 0; i < batch; ++i) { + problem_sizes_host.push_back({m * ((i % 2) + 1), n * ((i % 3) + 1), k * ((i % 2) + 1)}); + } + problem_sizes_device.reset(problem_sizes_host.size()); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + ProblemShapeType problem_shapes{batch, problem_sizes_device.get(), problem_sizes_host.data()}; + + if (CUTLASS_DEBUG_TRACE_LEVEL > 0) { + for (int i = 0; i < batch; ++i) { + std::cout << "problem_shapes : " << problem_shapes.get_host_problem_shape(i) << " \n"; + } + } + passed = testbed.run( + problem_shapes, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + else { + ProblemShapeType problem_shapes{{m, n, k, batch}}; + if (CUTLASS_DEBUG_TRACE_LEVEL > 0) { + std::cout << "problem_shapes : " << problem_shapes.get_host_problem_shape() << " \n"; + } + passed = testbed.run( + problem_shapes, + cutlass::from_real(alpha), + 
cutlass::from_real(beta) + ); + } + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNK " << m << " " << n << " " << k << " FAILED.\n"; + return false; + } + } // k + } // waves + } // batches + + return passed; +} + +template +bool TestSmallFusion(double alpha = 1.0, double beta = 0.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED) { + return TestSmall( + alpha, beta, check_relative_equality, use_device_scalars, vector_scale_mode); +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8b00f98a97846de175f1c6f95919c483ab4b81da --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp @@ -0,0 +1,515 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface with elementwise tensor-tensor broadcast epilogue +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "testbed_utils.h" +#include "gemm_testbed_3x.hpp" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Testbed3xTensorBroadcast { + + using TestBedImpl = typename detail::TestbedImpl; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + + using ElementA = typename Kernel::ElementA; + using StrideA = typename Kernel::StrideA; + using ElementB = typename Kernel::ElementB; + using StrideB = typename Kernel::StrideB; + using ElementC = typename Kernel::ElementC; + using StrideC = typename Kernel::StrideC; + using ElementD = typename Kernel::ElementD; + using StrideD = typename Kernel::StrideD; + + using ElementAccumulator = typename Kernel::ElementAccumulator; + using ElementCompute = typename Epilogue::ElementCompute; + using ElementScalar = typename Epilogue::ElementScalar; + using ProblemShapeType = typename Kernel::ProblemShape; + using ElementBias = typename Epilogue::ElementBias; + using ActivationFunctor = typename Epilogue::ActivationFunctor; + + static constexpr bool IsBinaryOp0Enabled = Epilogue::IsBinaryOp0Enabled; + static constexpr bool IsBinaryOp1Enabled = Epilogue::IsBinaryOp1Enabled; + static constexpr bool IsUnaryOpEnabled = Epilogue::IsUnaryOpEnabled; + + static constexpr bool PerColBias = Epilogue::PerColumnBias; + + using LayoutTagA = typename TestBedImpl::LayoutTagA; + using LayoutTagB = typename TestBedImpl::LayoutTagB; + using LayoutTagC = typename TestBedImpl::LayoutTagC; + using LayoutTagD = typename TestBedImpl::LayoutTagD; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + cutlass::HostTensor bias; + cutlass::HostTensor tensor_C1; + // 
tensor_C0 is taken from TestbedImpl's tensor_C + + + // Detail Implementation + TestBedImpl impl_; + + // + // Methods + // + Testbed3xTensorBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_) { } + + Testbed3xTensorBroadcast( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(stride_factor_A_, + stride_factor_B_, + stride_factor_C_, + stride_factor_D_, + CheckEquality::EXACT, ScalarLoc::ON_HOST, VectorScale::ENABLED, + init_A_, + init_B_, + init_C_, + cutlass::Distribution::Uniform, + cutlass::Distribution::Uniform, + seed_) { } + + /// Initializes data structures + void initialize(ProblemShapeType problem_size) { + // + // Allocate the GEMM workspace for A/B/C/D tensor + // + impl_.initialize(problem_size); + } + + void initialize_bias(ProblemShapeType problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto bias_size = PerColBias ? 
cute::get<1>(problem_shape_MNKL) : cute::get<0>(problem_shape_MNKL); + bias.resize(cutlass::Coord<1>(bias_size)); + + EXPECT_TRUE(detail::initialize_tensor(bias.host_view(), cutlass::Distribution::Uniform, impl_.collective_mma_inputs.seed + 2023)); + bias.sync_device(); + } + + void initialize_c1(ProblemShapeType problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + + auto c_coord = cutlass::make_Coord(M * L, N); + + tensor_C1.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.collective_epilogue.stride_factor_C)); + EXPECT_TRUE(detail::initialize_tensor(tensor_C1.host_view(), cutlass::Distribution::Uniform, impl_.collective_mma_inputs.seed + 2024)); + tensor_C1.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta, + bool use_bias) + { + auto [M, N, K, L] = problem_shape_MNKL; + + impl_.collective_epilogue.tensor_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_mma_inputs.tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_mma_inputs.tensor_B.host_view()), 0); + + if (impl_.collective_epilogue.tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_epilogue.tensor_D.host_view()), 0); + } + + if (impl_.collective_epilogue.reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_epilogue.reference_D.host_view()), 0); + } + + bool passed = cutlass::reference::host::TensorEquals(impl_.collective_epilogue.reference_D.host_view(), impl_.collective_epilogue.tensor_D.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + fname << 
"error_Gemm_device_broadcast" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L + << ", alpha: " << float(alpha) << ", beta: " << float(beta) << ", use_bias: " << use_bias + << ", per-col bias: " << PerColBias << "\n\n"; + + if (use_bias){ + file << "Bias = \n" << bias.host_view()<< "\n\n"; + } + + file + << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view() + << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view() + << "\nC0 =\n" << impl_.collective_epilogue.tensor_C.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\n\nReference =\n" << impl_.collective_epilogue.reference_D.host_view() + << "\n\nComputed =\n" <(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + + auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(), + cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a)); + auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(), + cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b)); + auto D = cute::make_tensor(impl_.collective_epilogue.reference_D.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d)); + auto Bias = cute::make_tensor(static_cast(use_bias ? bias.host_data() : nullptr), + cute::make_layout(PerColBias ? 
cute::make_shape(1, N) : cute::make_shape(M, 1))); + auto C0 = cute::make_tensor(impl_.collective_epilogue.tensor_C.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + auto C1 = cute::make_tensor(tensor_C1.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + + // Create host workspace for output of testbed. This computes a portion of the epilogue: + // ref_compute_out = Activation(alpha * (A @ B) + bias) + cutlass::HostTensor ref_compute_out; + auto c_coord = cutlass::make_Coord(M * L, N); + ref_compute_out.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.collective_epilogue.stride_factor_C), false); + auto RefComputeOut = cute::make_tensor(ref_compute_out.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + // Use a dummy null tensor for operand C because the epilogue overrides C. 
+ auto dummy_C = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + ElementCompute dummy_beta(0); + auto dummy_Aux = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d)); + auto dummy_Valpha = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, 1), cute::make_stride(cute::_1{}, cute::_0{}, M))); + auto dummy_Vbeta = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, 1), cute::make_stride(cute::_1{}, cute::_0{}, M))); + + auto dummy_SFD = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + using DummySFDVectorSize = cute::Int<0>; + + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(dummy_C), + decltype(RefComputeOut), + decltype(Bias), + decltype(dummy_Aux), + decltype(dummy_Valpha), + decltype(dummy_Vbeta), + ActivationFunctor, + decltype(dummy_SFD), + DummySFDVectorSize, + cutlass::plus, + PerColBias> epilogue_params{ + alpha, + dummy_beta, + dummy_C, + RefComputeOut, + Bias, + dummy_Aux, + dummy_Valpha, + dummy_Vbeta + }; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + cutlass::NumericConverter source_converter; + cutlass::NumericConverter destination_converter; + cutlass::multiplies mul; + + // Compute broadcast operations atop the reference + #pragma omp parallel for collapse(3) + for (int64_t l = 0; l < cute::size<2>(A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(A.layout()); ++m) { + for (int64_t n = 0; n < cute::size<0>(B.layout()); ++n) { + ElementCompute intermediate = RefComputeOut(m, n, l); + // Apply BinaryOp0, if needed + if constexpr (IsBinaryOp0Enabled) { + typename Epilogue::ThreadEpilogueOp::BinaryOp0 bin0; + ElementCompute converted_source = 
source_converter(C0(m, n, l)); + intermediate = bin0(intermediate, mul(beta, converted_source)); + } + + // Apply BinaryOp1, if needed + if constexpr (IsBinaryOp1Enabled) { + typename Epilogue::ThreadEpilogueOp::BinaryOp1 bin1; + ElementCompute converted_source = source_converter(C1(m, n, l)); + intermediate = bin1(intermediate, mul(beta, converted_source)); + } + + // Apply UnaryOp, if needed + if constexpr (IsUnaryOpEnabled) { + typename Epilogue::ThreadEpilogueOp::UnaryOp unary; + intermediate = unary(intermediate); + } + + D(m, n, l) = destination_converter(intermediate); + } + } + } + + return compare_reference(problem_shape_MNKL, alpha, beta, use_bias); + } + + /// Executes one test + bool run( + ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + bool profiling = false, + int iterations = 20, + bool use_bias = true) + { + // Fail test if insufficient CUDA device + if (!impl_.sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." 
<< std::endl; + return false; + } + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = impl_.sm_count; + } + else { + impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = impl_.sm_count; + } + + /// Initializes data structures + /// A/B/C0/D Tensor + initialize(problem_size); + initialize_bias(problem_size); + + if constexpr (IsBinaryOp1Enabled) { + initialize_c1(problem_size); + } + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a, + impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b, + impl_.mma_promotion_interval + }, + { // Epilogue arguments + { alpha, beta }, // ThreadOp arguments + impl_.collective_epilogue.stride_c, + impl_.collective_epilogue.tensor_D.device_data(), + impl_.collective_epilogue.stride_d, + use_bias ? 
bias.device_data() : nullptr, + impl_.collective_epilogue.tensor_C.device_data(), + tensor_C1.device_data() + }, // Epilogue arguments end + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + + if (profiling) { + return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); + } + else { + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, alpha, beta, use_bias); + if (!passed) { + std::cout << "Error : Failed : with alpha: " << float(alpha) + << ", beta: " << float(beta) + << ", use_bias: " << use_bias + << "\n"; + } + + return passed; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllTensorBroadcast(bool use_bias=true) { + using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; + std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + + if constexpr (cute::is_same_v) { + problem_size_m.push_back(768); + problem_size_n.push_back(768); + } + + constexpr int Stages = 
Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + + Testbed3xTensorBroadcast testbed; + bool passed = true; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + for (bool use_bias : {true, false}) { + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(1), + false, // profiling + 20, // iterations + use_bias + ); + + if (!passed) { + return false; + } + } + } + } + } + + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(1), + false, // profiling + 20 // iterations + ); + if (!passed) { + return false; + } + } + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..6ae7b864cb272782da4920ffc038830d3b5984b2 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed.h @@ -0,0 +1,300 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct MultistageTestbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = + typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + // + // Methods + // + + MultistageTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080) + : init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {} + + /// Helper to initialize a tensor view + template + bool initialize_tensor(cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, uint64_t seed) { + if (dist_kind == cutlass::Distribution::Uniform) { + int scope = (cutlass::sizeof_bits::value == 8) ? 
2 : 8; + cutlass::reference::host::TensorFillRandomUniform(view, seed, scope, + -scope, 0); + } else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, -1); + } else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), + view.capacity()); + } else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Waives test if CUDA device is insufficient + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run(cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waives test if CUDA device is insufficient + if (!sufficient()) { + return true; + } + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor + tensor_A(problem_size.mk()); + + cutlass::HostTensor + tensor_B(problem_size.kn()); + + cutlass::HostTensor + tensor_C(problem_size.mn()); + + cutlass::HostTensor + tensor_D(problem_size.mn()); + + cutlass::HostTensor + reference_D(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, 
seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), + tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, tensor_A.device_ref(), tensor_B.device_ref(), + tensor_C.device_ref(), tensor_D.device_ref(), {alpha, beta}}; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Verify + // + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, alpha, tensor_A.host_ref(), tensor_B.host_ref(), beta, + reference_D.host_ref(), ElementAccumulator(0)); + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed); + if (!passed) { + std::stringstream fname; + + fname << "error_Gemm_device_" << problem_size.m() << "x" + << problem_size.n() << "x" << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" << Gemm::ThreadblockShape::kN + << "x" << Gemm::ThreadblockShape::kK << "_" << Gemm::WarpShape::kM + << "x" << Gemm::WarpShape::kN << "x" << Gemm::WarpShape::kK + << 
".txt"; + + std::ofstream file(fname.str()); + + file << "problem: " << problem_size << ", alpha: " << alpha + << ", beta: " << beta << "\n\n"; + + file << "A =\n" + << tensor_A.host_view() << "\nB =\n" + << tensor_B.host_view() << "\nC =\n" + << tensor_C.host_view() << "\n\nReference =\n" + << reference_D.host_view() << "\nComputed =\n" + << tensor_D.host_view(); + } + + return passed; + } + + /// Runs a set of problem sizes + bool run_all() { + bool passed = true; + + int problem_size_m[] = {16, 528}; + + int problem_size_n[] = {16, 528}; + + int problem_size_k[] = {Gemm::InstructionShape::kK, + Gemm::ThreadblockShape::kK * Gemm::kStages + + Gemm::InstructionShape::kK}; + + double problem_alpha[] = {1.0}; + + // TODO Try non zero beta value after multistaged epilogue is implemented + double problem_beta[] = {0.0}; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (double alpha : problem_alpha) { + for (double beta : problem_beta) { + passed = + run({m, n, k}, ElementCompute(alpha), ElementCompute(beta)); + + if (!passed) { + return false; + } + } + } + } + } + } + + return true; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h new file mode 100644 index 0000000000000000000000000000000000000000..e309208bb4311253be5b7366841164eb62748bab --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h @@ -0,0 +1,348 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct MultistageInterleavedTestbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + // + // Methods + // + + MultistageInterleavedTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + 
cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerMultiprocessor < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm::ElementA, + typename Gemm::LayoutA> tensor_A(problem_size.mk()); + + cutlass::HostTensor< + typename Gemm::ElementB, + typename Gemm::LayoutB> tensor_B(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm::ElementB, + typename Gemm::LayoutB> tensor_B_reordered(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> tensor_C(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> tensor_D(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> reference_D(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + cutlass::reorder_column( + tensor_B_reordered.host_ref(), tensor_B.host_ref(), problem_size); + + cutlass::reference::host::TensorCopy( + reference_D.host_view(), + tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B_reordered.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + tensor_A.device_ref(), + tensor_B_reordered.device_ref(), + tensor_C.device_ref(), + tensor_D.device_ref(), + {alpha, beta} + }; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(arguments); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Verify + // + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, + 
ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + beta, + reference_D.host_ref(), + ElementAccumulator(0) + ); + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D.host_view(), + tensor_D.host_view()); + + EXPECT_TRUE(passed); + if (!passed) { + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nB_reordered =\n" << tensor_B_reordered.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Runs a set of problem sizes + bool run_all() { + bool passed = true; + + int problem_size_m[] = { + InterleavedK, 512 + InterleavedK + }; + + int problem_size_n[] = { + InterleavedK, 512 + InterleavedK + }; + + int problem_size_k[] = { + InterleavedK, Gemm::ThreadblockShape::kK * Gemm::kStages + InterleavedK + }; + + double problem_alpha[] = { + 1.0 + }; + + double problem_beta[] = { + 0.0 + }; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (double alpha : problem_alpha) { + for (double beta : problem_beta) { + + passed = run( + {m, n, k}, + 
ElementCompute(alpha), + ElementCompute(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + + return true; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/simt_sm50.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/simt_sm50.py new file mode 100644 index 0000000000000000000000000000000000000000..a180028205abb689436c73403eea82758ade7da9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/simt_sm50.py @@ -0,0 +1,341 @@ +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# this file creates the test/unit/gemm/device simt tests + + +outputDir = "" + +################################################################################ +# parameters +# Edge - for tiles, the edges represent the length of one side +# Ratio - the maximum ratio between 2 edges, limits the skinnyness of tiles +# MaxEdge - maximum length of each edge +# Min/Max - minimum/maximum of the product of edge lengths +################################################################################ + +warpsPerThreadblockEdge = [1, 2, 4, 8, 16] +warpsPerThreadblockRatio = 2 +warpsPerThreadblockMax = 16 +# NOTE 1x32 and 2x16 warp tile shapes fail validation for ~10% of cases + +warpShapeEdges = [8, 16, 32, 64, 128, 256] +warpShapeRatio = 4 +warpShapeMax = 64*64 +warpShapeMin = 8*8 + +threadblockEdgeMax = 256 + +# char, type bits/elem, max tile, L0 threadblock tiles +precisions = [ + ["c", "cutlass::complex", 64, 64*128, [ [ 64, 128], [ 64, 32] ] ], + ["q", "cutlass::Quaternion", 64, 64*128, [ [ 64, 128], [ 64, 32] ] ], + ["d", "double", 64, 64*64, [ [ 64, 64], [ 32, 32] ] ], + ["h", "cutlass::half_t", 16, 128*256, [ [256, 128], [ 64, 128], [ 64, 32] ] ], + ["i", "int", 32, 128*128, [ [128, 64], [ 16, 32] ] ], + ["s", "float", 32, 128*128, [ [128, 256], [128, 128], [ 64, 64] ] ], + ["z", "cutlass::complex", 128, 64*64, [ [ 32, 64], [ 16, 32] ] ], + ] +# L1 will have a single kernel for every unique shape +# L2 will have everything else + 
+transposes = [ + [False, False], + [False, True], + [True, False], + [True, True] + ] + +################################################################################ +# warps per threadblock +################################################################################ +warpsPerThreadblocks = [] +for warpsPerThreadblock0 in warpsPerThreadblockEdge: + for warpsPerThreadblock1 in warpsPerThreadblockEdge: + if warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio and warpsPerThreadblock1 / warpsPerThreadblock0 <= warpsPerThreadblockRatio and warpsPerThreadblock0 * warpsPerThreadblock1 <= warpsPerThreadblockMax: + warpsPerThreadblocks.append([warpsPerThreadblock0, + warpsPerThreadblock1]) +print("WarpsPerThreadblocks",warpsPerThreadblocks) + +################################################################################ +# warp shapes +################################################################################ +warpNumThreads = 32 +warpShapes = [] +for warp0 in warpShapeEdges: + for warp1 in warpShapeEdges: + if warp0 / warp1 <= warpShapeRatio and warp1 / warp0 <= warpShapeRatio and warp0*warp1 <= warpShapeMax and warp0*warp1 > warpShapeMin: + warpShapes.append([warp0, warp1]) +print("WarpShapes", warpShapes) + +numL0 = 0 +numL1 = 0 +numL2 = 0 + +################################################################################ +# create kernels +# create a file for each precision/transpose +# each file contains many tile sizes +################################################################################ + +# precisions +for precision in precisions: + + # get precision char + precisionChar = precision[0] + precisionType = precision[1] + precisionBits = precision[2] + threadblockMaxElements = precision[3] + threadblockTilesL0 = precision[4] + + # transposes + for transpose in transposes: + + # get transpose char + columnMajorA = transpose[0] + columnMajorB = transpose[1] + transCharA = "n" if columnMajorA else "t" + transCharB = "n" if 
columnMajorB else "t" + + # open file + fileName="simt_%sgemm_%s%s_sm50.cu" % (precisionChar, transCharA, transCharB) + print("\n", fileName) + filePath = "%s%s" % (outputDir, fileName) + out = open(filePath, "w+") + + # write file header + out.write("/***************************************************************************************************\n" +" * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n" +" * SPDX-License-Identifier: BSD-3-Clause \n" +" * \n" +" * Redistribution and use in source and binary forms, with or without \n" +" * modification, are permitted provided that the following conditions are met: \n" +" * \n" +" * 1. Redistributions of source code must retain the above copyright notice, this \n" +" * list of conditions and the following disclaimer. \n" +" * \n" +" * 2. Redistributions in binary form must reproduce the above copyright notice, \n" +" * this list of conditions and the following disclaimer in the documentation \n" +" * and/or other materials provided with the distribution. \n" +" * \n" +" * 3. Neither the name of the copyright holder nor the names of its \n" +" * contributors may be used to endorse or promote products derived from \n" +" * this software without specific prior written permission. \n" +" * \n" +" * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" \n" +" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE \n" +" * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE \n" +" * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE \n" +" * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL \n" +" * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR \n" +" * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER \n" +" * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, \n" +" * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE \n" +" * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \n" +" *\n" +" **************************************************************************************************/\n" +"/*! \\file\n" +" \\brief Tests for device-wide GEMM interface\n" +"*/\n" +"\n" +"#include \n" +"\n" +"#include \"cutlass/cutlass.h\"\n" +"#include \"cutlass/gemm/device/gemm.h\"\n" +"#include \"cutlass/numeric_types.h\"\n" +"\n" +"#include \"../../common/cutlass_unit_test.h\"\n" +"\n" +"#include \"cutlass/util/host_tensor.h\"\n" +"#include \"cutlass/util/tensor_view_io.h\"\n" +"#include \"cutlass/util/reference/host/tensor_fill.h\"\n" +"#include \"cutlass/util/reference/host/tensor_copy.h\"\n" +"#include \"cutlass/util/reference/host/tensor_compare.h\"\n" +"#include \"cutlass/util/reference/host/gemm.h\"\n" +"\n" +"#include \"testbed.h\"\n" +"\n") + foundThreadblockTilesL0 = {} + foundThreadblockTilesL1 = {} + + ######################################################################## + # for each combination of tile sizes + ######################################################################## + for warpsPerThreadblock in warpsPerThreadblocks: + for warpShape in warpShapes: + warpThreadsM = 0 + if warpShape[0] > warpShape[1]: + warpThreadsM = 8 + else: + warpThreadsM = 4 + warpThreadsN = warpNumThreads / warpThreadsM + + # skip shapes with conflicting rectangularity + # they are unlikely to be fastest + blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1] + blockL = 
warpsPerThreadblock[0] < warpsPerThreadblock[1] + warpG = warpShape[0] > warpShape[1] + warpL = warpShape[0] < warpShape[1] + + blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1]*2 + blockL2 = warpsPerThreadblock[0]*2 < warpsPerThreadblock[1] + warpG2 = warpShape[0] > warpShape[1]*2 + warpL2 = warpShape[0]*2 < warpShape[1] + + if blockG2 and warpL: continue + if blockL2 and warpG: continue + if warpG2 and blockL: continue + if warpL2 and blockG: continue + + # check threadblock ratios and max + threadblockTile = [warpShape[0]*warpsPerThreadblock[0], + warpShape[1]*warpsPerThreadblock[1]] + if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements: continue + if threadblockTile[0] > threadblockEdgeMax: continue + if threadblockTile[1] > threadblockEdgeMax: continue + totalThreads = warpNumThreads*warpsPerThreadblock[0]*warpsPerThreadblock[1] + + # calculate unroll + # ensure that every iteration at least a full load of A,B are done + unrollMin = 8 + unrollMin0 = totalThreads / threadblockTile[0] + unrollMin1 = totalThreads / threadblockTile[1] + unroll = max(unrollMin, unrollMin0, unrollMin1) + + threadTileM = warpShape[0] / warpThreadsM + threadTileN = warpShape[1] / warpThreadsN + if threadTileM < 2 or threadTileN < 2: continue + if threadTileM*threadTileN*precisionBits > 8*8*32: continue + + # epilogue currently only supports N < WarpNumThreads + if threadblockTile[1] < warpNumThreads: continue + + # limit smem + smemBitsA = threadblockTile[0]*unroll*2*precisionBits + smemBitsB = threadblockTile[1]*unroll*2*precisionBits + smemKBytes = (smemBitsA+smemBitsB)/8/1024 + if (smemKBytes > 48): continue + + # test level 0 + testLevel = -1 + for tileId in range(0, len(threadblockTilesL0)): + tbTile = threadblockTilesL0[tileId] + if tbTile[0] == threadblockTile[0] and tbTile[1] == threadblockTile[1]: + if tuple(tbTile) not in foundThreadblockTilesL0: + testLevel = 0 + numL0 += 1 + foundThreadblockTilesL0[tuple(tbTile)] = True + + # test level 1 + if 
testLevel < 0: + threadblockTileAlreadyUsed = False + if tuple(threadblockTile) not in foundThreadblockTilesL1: + testLevel = 1 + numL1 += 1 + foundThreadblockTilesL1[tuple(threadblockTile)] = True + + # test level 2 + if testLevel < 0: + testLevel = 2 + numL2 += 1 + + ################################################################ + # write this tile to file + ################################################################ + + print("%ix%ix%i__%ix%i_%ix%i_%ix%i L%i" % ( + threadblockTile[0], threadblockTile[1], unroll, + threadTileM, threadTileN, + warpThreadsM, warpThreadsN, + warpsPerThreadblock[0], warpsPerThreadblock[1], testLevel)) + + out.write("////////////////////////////////////////////////////////////////////////////////\n" + "// Elements / Thread: %3i x %3i\n" + "// Threads / Warp: %3i x %3i\n" + "// Warps / Block: %3i x %3i\n" + "// Threadblock: %3i x %3i x %2i\n" + % ( threadTileM, threadTileN, + warpThreadsM, warpThreadsN, + warpsPerThreadblock[0], warpsPerThreadblock[1], + threadblockTile[0], threadblockTile[1], unroll + ) + ) + + out.write("CUTLASS_TEST_L%i(SM50_device_%sgemm_%s%s, %ix%ix%i_%ix%ix1_%ix%i_%ix%i_%ix%i, {\n" % ( + testLevel, + precisionChar, + transCharA, + transCharB, + threadblockTile[0], + threadblockTile[1], + unroll, + warpShape[0], + warpShape[1], + threadTileM, + threadTileN, + warpThreadsM, + warpThreadsN, + warpsPerThreadblock[0], + warpsPerThreadblock[1] + )) + out.write(" using precision = %s;\n" % precisionType) + out.write(" using ThreadblockShape = cutlass::gemm::GemmShape<%i, %i, %i>;\n" % ( + threadblockTile[0], + threadblockTile[1], + unroll)) + out.write(" using WarpShape = cutlass::gemm::GemmShape<%i, %i, %i>;\n\n" % ( + warpShape[0], + warpShape[1], + unroll)) + out.write(" static int const kEpilogueElementsPerAccess = 1;\n" + " using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;\n" + " using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<\n" + " precision, kEpilogueElementsPerAccess, 
precision, precision>;\n\n") + + out.write(" using Gemm = cutlass::gemm::device::Gemm<\n" + " precision, cutlass::layout::%sMajor,\n" + " precision, cutlass::layout::%sMajor,\n" + " precision, cutlass::layout::RowMajor,\n" + " precision,\n" + " cutlass::arch::OpClassSimt,\n" + " cutlass::arch::Sm50,\n" + " ThreadblockShape, WarpShape, InstructionShape,\n" + " EpilogueOutputOp,\n" + " cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,\n" + " 2 // Stages\n" + " >;\n" % ( + "Column" if columnMajorA else "Row", + "Column" if columnMajorB else "Row", + )) + out.write(" EXPECT_TRUE(test::gemm::device::TestAllGemm());\n" + "} )\n\n") + + + out.close() +print("NumKernels:", numL0, numL1, numL2) + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/sm90_evt_operations.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/sm90_evt_operations.hpp new file mode 100644 index 0000000000000000000000000000000000000000..63ffc3281dd2b9e9f74e0024c73da00628331dd4 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/sm90_evt_operations.hpp @@ -0,0 +1,545 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Host reference and operations for Sm90 EVT unit test +*/ +#pragma once +#include "gemm_testbed_3x_evt.hpp" + +////////////////////////////////////////////////////////////////////////////// +/// Host references used for testing +namespace test::gemm::device { +template +using HEVT = HostTreeVisitor; + +template +using HDAG = HostTopoVisitor; + +template +using HST = HostSplitTreeVisitor; + +/// D = alpha * acc + beta * C + AuxLoad +template +class HostEVTAuxLoad { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using ScalarAlpha = HostScalarBroadcast<1>; + using AccFetchNode = HostAccumulator<>; + using AuxLoadNode = HostAuxLoad; + using TernaryCompute0 = HEVT, ScalarAlpha, AccFetchNode, AuxLoadNode>; + using ScalarBeta = HostScalarBroadcast<1>; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, TernaryCompute0>; + using EVTModule = HEVT, TernaryCompute1>; +}; + +/// D = alpha * acc + beta * C + per-column bias +template +class HostPerColBias { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using ScalarAlpha = HostScalarBroadcast<1>; + using AccFetchNode = HostAccumulator<>; + using RowBroadcastNode = HostRowBroadcast; + using TernaryCompute0 = HEVT, ScalarAlpha, AccFetchNode, RowBroadcastNode>; + using ScalarBeta = HostScalarBroadcast<1>; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, TernaryCompute0>; + using EVTModule = HEVT, TernaryCompute1>; +}; + +/// D = beta * C + Graph(relu(alpha * acc + aux) + aux) +/// Testing EVT - DAG structure +template +class HostEVTDAG { +public: + using ElementC = 
typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using ScalarAlpha = HostScalarBroadcast<1>; + using AccFetchNode = HostAccumulator<>; + using AuxLoadNode = HostAuxLoad; + using DAGNode = HDAG< + float, + cute::tuple< + cute::tuple<>, // 0. alpha + cute::tuple<>, // 1. acc + cute::tuple<>, // 2. aux load + cute::tuple, // 3. alpha * acc + aux load + cute::tuple, // relu(alpha * acc + aux load) + cute::tuple // relu(alpha * acc + aux load) + aux load + >, + ScalarAlpha, + AccFetchNode, + AuxLoadNode, + HostCompute, + HostCompute, + HostCompute + >; + using ScalarBeta = HostScalarBroadcast<1>; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, DAGNode>; + using EVTModule = HEVT, TernaryCompute1>; +}; + +/// EVT = alpha * acc + C +/// D = Graph(maximum(EVT + per-row bias, EVT)) +/// Testing DAG - EVT +template +class HostDAGEVT { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using EVTNode = HEVT< + HostAuxStore, + HEVT< + HostCompute, + HostScalarBroadcast<2>, + HostAccumulator<>, + HostAuxLoad + > + >; + using EVTModule = HEVT< + HostAuxStore, + HDAG< + float, + cute::tuple< + cute::tuple<>, // 0. EVT + cute::tuple<>, // 1. per-row bias + cute::tuple, // 2. EVT + per-row bias + cute::tuple // 3. 
maximum(EVT + per-row bias, EVT) + >, + EVTNode, + HostColBroadcast>, + HostCompute, + HostCompute + > + >; +}; + +/// Xreduce(alpha * acc + beta * C) +template +class HostReduce { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using ScalarAlpha = HostScalarBroadcast<1>; + using AccFetchNode = HostAccumulator<>; + using BinaryCompute0 = HEVT, ScalarAlpha, AccFetchNode>; + using ScalarBeta = HostScalarBroadcast<1>; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, BinaryCompute0>; + using ReduceNode = HEVT; + using EVTModule = HEVT, ReduceNode>; +}; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template class ActivationFn, class ElementD> +class HostScaledLinCombPerRowBiasEltAct { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using EVTModule = HEVT< + HostAuxStore, + HEVT< + HostCompute::template Op>, // activation(Z) * scaled_d + HEVT< + HostCompute, // activation(Z) + HEVT< + HostCompute, + HostScalarBroadcast<1, 2, cute::Stride>, // scale_c * beta + HostAuxLoad, // C + HEVT< + HostCompute, + HostScalarBroadcast<1, 3, cute::Stride>, // scale_a * scale_b * alpha + HostAccumulator<>, + HostColBroadcast> + > + > + >, + HostScalarBroadcast<1> // scale_d + > + >; +}; + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z +template class ActivationFn, 
class ElementD, class ElementAux = ElementD> +class HostScaledLinCombPerRowBiasEltActAmaxAux { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + template + using amax = cutlass::maximum_absolute_value_reduction; + using EVTModuleAuxFp8 = HEVT< + HostAuxStore, + HST, + HostScalarBroadcast<1, 2, cute::Stride>, // scale_c * beta + HostAuxLoad, // C + HEVT< + HostCompute, + HostScalarBroadcast<1, 3, cute::Stride>, // scale_a * scale_b * alpha + HostAccumulator<>, + HostColBroadcast> + > + >, + // D = activation(Z) * scaled_d, amax_d = max(abs(elements in D)) + HEVT< + HostCompute::template Op>, + HEVT< + HostScalarReduce, + HEVT< + HostCompute, //activation(Z) * scaled_d + HostAccumulator<> // Z + > + >, + HostScalarBroadcast<1> // scale_d + >, + // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) + HEVT< + HostAuxStore, + HEVT< + HostCompute, + HEVT< + HostScalarReduce, + HostAccumulator<> + >, + HostScalarBroadcast<1> + > + > + > + >; + + using EVTModuleAuxNotFp8 = HEVT< + // D = activation(Z) * scaled_d, amax_d = max(abs(elements in D)) + HostAuxStore, + HEVT< + HostCompute::template Op>, + HEVT< + HostScalarReduce, + HEVT< + HostCompute, //activation(Z) * scaled_d + HEVT< + // Aux = Z + HostAuxStore, + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias + HEVT< + HostCompute, + HostScalarBroadcast<1, 2, cute::Stride>, // scale_c * beta + HostAuxLoad, // C + HEVT< + HostCompute, + HostScalarBroadcast<1, 3, cute::Stride>, // scale_a * scale_b * alpha + HostAccumulator<>, + HostColBroadcast> + > + > + > + > + >, + HostScalarBroadcast<1> // scale_d + > + >; + + using EVTModule = cute::conditional_t, EVTModuleAuxFp8, EVTModuleAuxNotFp8>; + +}; +} // namespace test::gemm::device + +////////////////////////////////////////////////////////////////////////////// +namespace cutlass::epilogue { +namespace fusion { 
+ +namespace detail { + +template +struct maximum_with_default_nan_propagation : maximum {}; + +} // namespace detail + +////////////////////////////////////////////////////////////////////////////// +/// D = alpha * acc + beta * C + AuxLoad +template< + class EpilogueDescriptor, + class AuxLoadDescriptor, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombAuxLoad = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90AuxLoad< + AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, + typename AuxLoadDescriptor::Element, + typename AuxLoadDescriptor::Stride, typename AuxLoadDescriptor::SmemLayoutAtom, + typename AuxLoadDescriptor::CopyOpS2R // aux load + > + > + >; + +////////////////////////////////////////////////////////////////////////////// +/// D = alpha * acc + beta * C + AuxLoadNoSmem +template< + class EpilogueDescriptor, + class ElementAux, + class StrideAux, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombAuxLoadNoSmem = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90AuxLoad<0, void, ElementAux, StrideAux, void, void> // aux load + > + >; + +////////////////////////////////////////////////////////////////////////////// +/// Example DAG +/// beta * C + Graph(alpha * acc + gamma + acc) +template< + typename EpilogueDescriptor, + typename AuxLoadDescriptor, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using 
Sm90LinCombEVTDAG = + Sm90EVT, // beta * C + (alpha * acc + aux) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90TopologicalVisitor< + ElementCompute, + cute::tuple< + cute::seq<>, // 0. alpha + cute::seq<>, // 1. acc + cute::seq<>, // 2. aux load + cute::seq<1, 0, 2>, // 3. alpha * acc + aux load + cute::seq<3>, // relu(alpha & acc + aux load) + cute::seq<2, 4> // relu(alpha * acc + aux load) + aux load + >, + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90AuxLoad< + AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, + typename AuxLoadDescriptor::Element, typename AuxLoadDescriptor::Stride, + typename AuxLoadDescriptor::SmemLayoutAtom, typename AuxLoadDescriptor::CopyOpS2R>, + Sm90Compute, + Sm90Compute, + Sm90Compute + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// Example DAG +/// EVT = alpha * acc + C +/// D = Graph(maximum(EVT + per-row bias, EVT)) +template< + class EpilogueDescriptor, + class AuxStoreDescriptor, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombDAGEVT = + Sm90TopologicalVisitor< + ElementCompute, + cute::tuple< + cute::seq<>, + cute::seq<>, + cute::seq<1, 0>, + cute::seq<0, 2> + >, + Sm90EVT< + Sm90AuxStore< + AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, + typename AuxStoreDescriptor::Element, RoundStyle, typename AuxStoreDescriptor::Stride, + typename AuxStoreDescriptor::SmemLayoutAtom, typename AuxStoreDescriptor::CopyOpR2S>, + Sm90EVT, + Sm90ScalarBroadcast, + Sm90AccFetch, + Sm90SrcFetch + > + >, + Sm90ColBroadcast<0, typename EpilogueDescriptor::TileShape, ElementBias, ElementCompute>, + Sm90Compute, + Sm90Compute + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = alpha * acc + beta * C + per-column bias +template< + 
class EpilogueDescriptor, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColumnBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90RowBroadcast<0, typename EpilogueDescriptor::TileShape, ElementBias, ElementCompute> + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = per-column reduce(alpha * acc + beta * C) +template< + template class RegReduceFn, + template class GmemReduceFn, + class ElementReduce, + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColumnReduce = + Sm90EVT, // per column reduce + Sm90EVT, // beta * C + alpha * acc + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast, // alpha + Sm90AccFetch // acc + > + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = per-row reduce(alpha * acc + beta * C) +template< + template class RegReduceFn, + template class GmemReduceFn, + class ElementReduce, + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerRowReduce = + Sm90EVT, // per column reduce + Sm90EVT, // beta * C + alpha * acc + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast, // alpha + Sm90AccFetch // acc + > + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = scalar reduce(alpha * acc + beta * C) +template< + 
template class RegReduceFn, + template class GmemReduceFn, + class ElementReduce, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombScalarReduce = + Sm90EVT, // per column reduce + Sm90EVT, // beta * C + alpha * acc + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast, // alpha + Sm90AccFetch // acc + > + > + >; +} // namespace fusion + +} // namespace cutlass::epilogue diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..0007666cdd084f35015200e36fd47f75971f6c1c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed.h @@ -0,0 +1,639 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed_utils.h" +#include "testbed_universal.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Testbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = 
typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + typename Gemm::LayoutA::Stride stride_factor_A; + typename Gemm::LayoutB::Stride stride_factor_B; + typename Gemm::LayoutC::Stride stride_factor_C; + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + Testbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + stride_factor_A(typename Gemm::LayoutA::Stride()), + stride_factor_B(typename Gemm::LayoutB::Stride()), + stride_factor_C(typename Gemm::LayoutC::Stride()), + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + Testbed( + typename Gemm::LayoutA::Stride stride_factor_A_, + typename Gemm::LayoutB::Stride stride_factor_B_, + typename Gemm::LayoutC::Stride stride_factor_C_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + stride_factor_C(stride_factor_C_), + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + 
int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 1; + scope_min = -1; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mk(), stride_factor_A)); + tensor_B.resize(problem_size.kn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.kn(), stride_factor_B)); + tensor_C.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C)); + tensor_D.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C)); + reference_D.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + // It is possible to randomly initialize to all zeros, so override 
this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at(cutlass::make_Coord(0, 0)) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0) + << "tensor_D (size " << tensor_D.size() << ") has nonpositive norm"; + } + if (reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0) + << "reference_D (size " << reference_D.size() << ") has nonpositive norm"; + } + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed) << "reference_D does not equal tensor_D"; + + if (!passed) { + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << 
"\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + beta, + reference_D.host_ref(), + ElementAccumulator(0) + ); + + if (Relu) { + for (int i = 0; i < problem_size.m(); ++i) { + for (int j = 0; j < problem_size.n(); ++j) { + reference_D.at(cutlass::MatrixCoord(i, j)) = + ((ElementCompute)reference_D.at(cutlass::MatrixCoord(i, j)) < (ElementCompute)0) + ? 
(typename Gemm::ElementC)0 + : reference_D.at(cutlass::MatrixCoord(i, j)); + } + } + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Determine if the CUDA device is sufficient to run the kernel + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + int split_k_slices = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) + { +/* + std::cout << "\n-----------------------\n"; + std::cout << "problem size: " << problem_size << "\n"; + std::cout << "split_k_slices: " << split_k_slices << "\n"; + std::cout << "alpha: " << alpha << "\n"; + std::cout << "beta: " << beta << "\n"; + std::cout << "-----------------------\n\n"; +*/ + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D.device_ref(), + {alpha, beta}, + split_k_slices + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) + << "gemm_op.initialize returned with error " << to_string(status) + << ", indicating that this test is not supported. Last CUDA error: " + << cudaGetErrorString(cudaGetLastError()); + if (status != cutlass::Status::kSuccess) { + return true; + } + + // + // Run the GEMM + // + + try { + status = gemm_op(); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "gemm_op() threw a std::exception: " << e.what(); + throw; + } + catch (...) 
{ + EXPECT_TRUE(false) << "gemm_op() threw an exception of unknown type"; + throw; + } + EXPECT_TRUE(status == cutlass::Status::kSuccess) + << "gemm_op failed with error " << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + EXPECT_TRUE(passed) << "Error: split_k_slices = " << split_k_slices + << ", alpha: " << alpha; + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllGemmBasic( + const typename Gemm::LayoutA::Stride& stride_factor_A = typename Gemm::LayoutA::Stride(), + const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(), + const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride()) { + bool passed = true; + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + int const kAlignment = cutlass::platform::is_same< + typename Gemm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + (cutlass::platform::is_same::value || + cutlass::platform::is_same::value) ? 
4 : kAlignment; + + int problem_size_m[] = {kAlignmentM, 512 - 3 * kAlignmentM}; + + int problem_size_n[] = {kAlignmentN, 512 - 2 * kAlignmentN}; + + int problem_size_k[] = { + kAlignmentK, Gemm::ThreadblockShape::kK * (Gemm::kStages + 1) - kAlignmentK}; + + int split_k_slices[] = { + 1, 2, 3 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + Testbed testbed(stride_factor_A, stride_factor_B, stride_factor_C); + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int split_k : split_k_slices) { + + if (!Gemm::kSplitKSerial && split_k > 1) { + continue; + } + + if (split_k > 1 && k / Gemm::ThreadblockShape::kK < split_k) { + continue; + } + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + cutlass::gemm::GemmCoord problem_size(m, n, k); + try { + passed = testbed.run( + problem_size, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "TestAllGemmBasic: testbed.run threw an " + "exception {alpha: " << alpha << ", beta: " << beta << ", m: " + << m << ", n: " << n << ", k: " << k << "}: " << e.what(); + throw; + } + catch (...) 
{ + EXPECT_TRUE(false) << "TestAllGemmBasic: testbed.run threw an " + "exception {alpha: " << alpha << ", beta: " << beta << ", m: " + << m << ", n: " << n << ", k: " << k << "}: (unknown)"; + throw; + } + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllGemm( + const typename Gemm::LayoutA::Stride& stride_factor_A, + const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(), + const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride()) +{ + // Test basic GEMM with non-default stride factors + return TestAllGemmBasic(stride_factor_A, stride_factor_B, stride_factor_C); +} + +template +bool TestAllGemm() +{ +#ifdef NDEBUG + // Non-debug builds also test basic GEMM with default stride factors + if (!TestAllGemmBasic()) { + return false; + } +#endif // NDEBUG + + // Test universal GEMM +#if 0 + // Define the universal kernel + using UniversalKernel = cutlass::gemm::kernel::GemmUniversal< + typename Gemm::GemmKernel::Mma, // Mma + typename Gemm::GemmKernel::Epilogue, // Epilogue + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<> // ThreadblockSwizzle + >; +#else + // Define the streamk universal kernel + using UniversalKernel = cutlass::gemm::kernel::GemmUniversalStreamk< + typename Gemm::GemmKernel::Mma, // Mma + typename Gemm::GemmKernel::Epilogue, // Epilogue + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK // ThreadblockSwizzle + >; +#endif + + // Define the universal adaptor + using UniversalGemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Test universal GEMM + return TestAllGemmUniversal(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestGemmPerf(int iterations = 1) { + bool passed = true; + + int problem_size_m[] = { 2048 }; + + int 
problem_size_n[] = { 4352 }; + + int problem_size_k[] = { 4096 }; + + int split_k_slices[] = { 1 }; + double problem_alpha[] = { 1 }; + double problem_beta[] = { 0.0 }; + + Testbed testbed; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int split_k : split_k_slices) { + + if (!Gemm::kSplitKSerial && split_k > 1) { + continue; + } + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + for (int i = 0; i < iterations; i++){ + try { + passed = testbed.run( + problem_size, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "TestGemmPerf: testbed.run threw an " + "exception {alpha: " << alpha << ", beta: " << beta << ", m: " + << m << ", n: " << n << ", k: " << k << "}: " << e.what(); + throw; + } + catch (...) 
{ + EXPECT_TRUE(false) << "TestGemmPerf: testbed.run threw an " + "exception {alpha: " << alpha << ", beta: " << beta << ", m: " + << m << ", n: " << n << ", k: " << k << "}: (unknown)"; + throw; + } + } + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_complex.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..add984ca3b9a0c05325b93cf52cbadd710527ba6 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_complex.h @@ -0,0 +1,294 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedComplex : public Testbed { + + using Base = Testbed; + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename 
Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + + // + // Methods + // + + TestbedComplex( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + Base(init_A_, init_B_, init_C_, seed_) { } + + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex( + problem_size, + alpha, + this->tensor_A.host_ref(), + Gemm::kTransformA, + this->tensor_B.host_ref(), + Gemm::kTransformB, + beta, + this->tensor_C.host_ref(), + this->reference_D.host_ref(), + ElementAccumulator(0) + ); + + return this->compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. 
+ bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + int split_k_slices = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + // + // Initialize workspace + // + + this->initialize(problem_size); + + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + this->tensor_A.device_ref(), + this->tensor_B.device_ref(), + this->tensor_C.device_ref(), + this->tensor_D.device_ref(), + {alpha, beta}, + split_k_slices + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha 
<< std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllGemmComplex() { + bool passed = true; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + int const kAlignment = + cutlass::platform::is_same< + typename Gemm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + int problem_size_m[] = { + kAlignment, 512 - 3*kAlignment + }; + + int problem_size_n[] = { + kAlignment, 512 - 2*kAlignment + }; + + int problem_size_k[] = { + kAlignment, 128 - kAlignment + }; + + int split_k_slices[] = { + 1, 2, 3 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + TestbedComplex testbed; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int split_k : split_k_slices) { + + if (!Gemm::kSplitKSerial && split_k > 1) { + continue; + } + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + passed = testbed.run( + problem_size, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h new file mode 
100644 index 0000000000000000000000000000000000000000..eca0b0ae0decf3293f6f73cb6ebbc5b5735a8e49 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h @@ -0,0 +1,670 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithBroadcastReferenceOp { + + using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + + using ElementCompute = typename OutputOp::ElementCompute; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + typename OutputOp::BinaryOp binary_op; + typename OutputOp::ElementwiseOp elementwise_op; + + GemmWithBroadcastReferenceOp() { } + + void operator()(ElementZ &Z, ElementT &T, ElementCompute gemm, ElementCompute bias) { + + ElementCompute t_full = binary_op(gemm, bias); + + if 
(OutputOp::kStoreT) { + T = ElementT(t_full); + } + + if (OutputOp::kStoreZ) { + ElementCompute z_full = elementwise_op(t_full); + Z = ElementZ(z_full); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Fused testbed +// +// Y = GEMM(AB, C) +// +// T[i, j] = BinaryOp(Y[i, j], Broadcast[i]) +// +// Z[i, j] = Elementwise(T[i, j]) +// + +template < + typename Gemm, + typename ReferenceOp = GemmWithBroadcastReferenceOp +> +struct TestbedGemmWithBroadcast { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename OutputOp::ElementCompute; + using ElementVector = typename OutputOp::ElementVector; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; // Input A + cutlass::HostTensor tensor_B; // Input B + cutlass::HostTensor tensor_C; // Input C + cutlass::HostTensor tensor_Broadcast; // Input Broadcast + + cutlass::HostTensor tensor_Z; + cutlass::HostTensor tensor_T; + + cutlass::HostTensor tensor_C_ref; + cutlass::HostTensor tensor_Y_ref; + cutlass::HostTensor tensor_Z_ref; + cutlass::HostTensor tensor_T_ref; + + + // + // Methods + // + + TestbedGemmWithBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool 
initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 1; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_Z.resize(problem_size.mn()); + tensor_T.resize(problem_size.mn()); + tensor_Broadcast.resize({ + problem_size.m(), + 1 + }); + + tensor_C_ref.resize(problem_size.mn()); + tensor_Y_ref.resize(problem_size.mn()); + tensor_Z_ref.resize(problem_size.mn()); + tensor_T_ref.resize(problem_size.mn()); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + EXPECT_TRUE(initialize_tensor(tensor_Broadcast.host_view(), init_C, 
seed + 2020)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + + for (int m = 0; m < tensor_C_ref.extent().row(); ++m) { + for (int n = 0; n < tensor_C_ref.extent().column(); ++n) { + tensor_C_ref.at({m, n}) = ElementAccumulator(tensor_C.at({m, n})); + } + } + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_Broadcast.sync_device(); + + tensor_Z.sync_device(); + tensor_T.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + tensor_Z.sync_host(); + tensor_T.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (OutputOp::kStoreZ) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Z.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Z_ref.host_view()), 0); + } + + if (OutputOp::kStoreT) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T_ref.host_view()), 0); + } + + bool passed = true; + float norm_diff = 0; + + if (OutputOp::kStoreZ) { + norm_diff = cutlass::reference::host::TensorNormDiff(tensor_Z_ref.host_view(), tensor_Z.host_view(), float()); + passed = (norm_diff <= 0.1f); + EXPECT_LT(norm_diff, 0.1f) << " tensor_Z is incorrect"; + } + + if (OutputOp::kStoreT) { + + norm_diff = cutlass::reference::host::TensorNormDiff(tensor_T_ref.host_view(), 
tensor_T.host_view(), float()); + passed = (passed && (norm_diff <= 0.1f)); + + EXPECT_LT(norm_diff, 0.1f) << " tensor_T is incorrect"; + } + + + if (!passed) { + + /* + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + */ + + std::ofstream file("errors_testbed_gemm_with_broadcast.txt"); + + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nZ =\n" << tensor_Z.host_view() + << "\nT =\n" << tensor_T.host_view() + << "\n\n" + << "\nY_ref =\n" << tensor_Y_ref.host_view() + << "\nZ_ref =\n" << tensor_Z_ref.host_view() + << "\nT_ref =\n" << tensor_T_ref.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + ElementAccumulator, typename Gemm::LayoutC, + ElementAccumulator, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C_ref.host_ref(), + tensor_Y_ref.host_ref(), + ElementAccumulator(0) + ); + + using ElementC = typename Gemm::ElementC; + + ReferenceOp reference_op; + + // compute tensor Z and tensor T + for (int m = 0; m < problem_size.m(); ++m) { + for (int n = 0; n < problem_size.n(); ++n) { + + ElementZ z; + ElementT t; + + reference_op(z, 
t, tensor_Y_ref.at({m, n}), tensor_Broadcast.at({m, 0})); + + if (OutputOp::kStoreZ) { + tensor_Z_ref.at({m, n}) = z; + } + + if (OutputOp::kStoreT) { + tensor_T_ref.at({m, n}) = t; + } + } + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementAccumulator alpha = ElementAccumulator(1), + ElementAccumulator beta = ElementAccumulator(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_Z.device_data(), + tensor_Broadcast.device_data(), + tensor_T.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_Z.layout().stride(0), + 0, // This must be zero + tensor_T.layout().stride(0), + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = true; + + passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + // + // Profile + // + + #if 0 // profiling disabled for now. 
+ + int const kWorkspaces = 100; + + cutlass::DeviceAllocation profiling_tensor_A(tensor_A.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_B(tensor_B.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_C(tensor_C.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_Broadcast(tensor_Broadcast.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_Z(tensor_Z.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_T(tensor_T.capacity() * kWorkspaces); + + cudaEvent_t events[2]; + for (auto & event : events) { + cudaError_t result = cudaEventCreate(&event); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << " cudaEventCreate() failed with error " << cudaGetErrorString(result); + return false; + break; + } + } + + int const kWarmupIterations = 5; + int const kProfilingIterations = 100; + + for (int i = 0; i < kWarmupIterations; ++i) { + status = gemm_op(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + } + + + cudaError_t result = cudaEventRecord(events[0]); + EXPECT_EQ(result, cudaSuccess); + + for (int i = 0; i < kProfilingIterations; ++i) { + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + profiling_tensor_A.get() + tensor_A.capacity() * (i % kWorkspaces), + profiling_tensor_B.get() + tensor_B.capacity() * (i % kWorkspaces), + profiling_tensor_C.get() + tensor_C.capacity() * (i % kWorkspaces), + profiling_tensor_Z.get() + tensor_Z.capacity() * (i % kWorkspaces), + profiling_tensor_Broadcast.get() + tensor_Broadcast.capacity() * (i % kWorkspaces), + profiling_tensor_T.get() + tensor_T.capacity() * (i % kWorkspaces), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + 
tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_Z.layout().stride(0), + 0, // This must be zero + tensor_T.layout().stride(0), + }; + + gemm_op.initialize(arguments, workspace.get()); + status = gemm_op(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + } + + result = cudaEventRecord(events[1]); + EXPECT_EQ(result, cudaSuccess); + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess); + + float elapsed_time = 0; + result = cudaEventElapsedTime(&elapsed_time, events[0], events[1]); + EXPECT_EQ(result, cudaSuccess); + + double average_time = double(elapsed_time) / double(kProfilingIterations); + + std::cout << problem_size << ": " << average_time << " ms" << std::endl; + + for (auto & event : events) { + cudaEventDestroy(event); + } + #endif + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename ReferenceOp = GemmWithBroadcastReferenceOp +> +bool TestGemmWithBroadcast( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedGemmWithBroadcast testbed; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename ReferenceOp = GemmWithBroadcastReferenceOp +> +bool TestAllGemmWithBroadcast() { + + int M_problems[] = {8, 136, 264, 520}; + int N_problems[] = {8, 136, 264, 520}; + int K_problems[] = {8, 136, 264, 520}; + double alpha_problems[] = {1.25, 2.25}; + double beta_problems[] = {0, 1, 2.0}; + + bool passed = true; + + for (int M : M_problems) { + for (int N : 
N_problems) { + for (int K : K_problems) { + for (double alpha : alpha_problems) { + for (double beta : beta_problems) { + + TestbedGemmWithBroadcast testbed; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + 1, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + EXPECT_TRUE(passed) + << "M: " << M << ", N: " << N << ", K: " << K << ", alpha: " << alpha << ", beta: " << beta; + + if (!passed) { + + return passed; + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h new file mode 100644 index 0000000000000000000000000000000000000000..af3629ccfb87e09e80b85af508379780d6428dc5 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h @@ -0,0 +1,588 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithReductionReference { + + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::ElementCompute; + using ElementC = typename Gemm::ElementC; + using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; + // + // Data members + // + + BinaryOp binary_op; + + // + // Methods + // + + GemmWithReductionReference() { } + + ElementCompute operator()( + ElementAccumulator d_y, + ElementT t) { + + return binary_op(ElementCompute(d_y), ElementCompute(t)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename ReferenceOp +> +struct TestbedGemmWithReduction { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; 
+ cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor tensor_Reduction; + cutlass::HostTensor tensor_Tensor; + cutlass::HostTensor tensor_C_ref; + cutlass::HostTensor reference_d_Y; + cutlass::HostTensor reference_D; + cutlass::HostTensor reference_Reduction; + + // + // Methods + // + + TestbedGemmWithReduction( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 1; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + for (int m = 0; m < view.extent().row(); ++m) { + for (int n = 0; n < view.extent().column(); ++n) { + //view.at({m, n}) = Element(float(((idx ++) % 17) - 8)); + view.at({m, n}) = (n == 0 ? 
Element(m) : Element()); + + } + } + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + + tensor_Reduction.resize({ + problem_size.m(), + (problem_size.n() - 1 + Gemm::ThreadblockShape::kN) / Gemm::ThreadblockShape::kN + }); + + tensor_Tensor.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + reference_d_Y.resize(problem_size.mn(), false); + tensor_C_ref.resize(problem_size.mn(), false); + reference_Reduction.resize({problem_size.m(), 1}, false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + EXPECT_TRUE(initialize_tensor(tensor_Tensor.host_view(), init_C, seed + 2020)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + + for (int m = 0; m < tensor_C_ref.extent().row(); ++m) { + for (int n = 0; n < tensor_C_ref.extent().column(); ++n) { + tensor_C_ref.at({m, n}) = ElementAccumulator(tensor_C.at({m, n})); + } + } + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + tensor_Reduction.sync_device(); + tensor_Tensor.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + tensor_Reduction.sync_host(); + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Reduction.host_view()), 0); + + bool passed = true; + for (int m = 0; m < tensor_Reduction.extent().row(); ++m) { + + ElementAccumulator reduced_value = ElementAccumulator(); + for (int j = 0; j < tensor_Reduction.extent().column(); ++j) { + reduced_value += tensor_Reduction.at({m, j}); + } + + if (reduced_value != reference_Reduction.at({m, 0})) { + std::cout << "Error in bias[" << m << "] - Expected: " << reference_Reduction.at({m, 0}) << ", got: " << reduced_value << std::endl; + passed = false; + break; + } + } + EXPECT_TRUE(passed) << "Reduction is incorect."; + + if (!cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view())) { + EXPECT_TRUE(false) << " mismatched 
reference"; + passed = false; + } + + if (!passed) { + + /* + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + */ + + std::ofstream file("testbed_universal_errors_sm70.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nT = \n" << tensor_Tensor.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view() + << "\n\nReduction =\n" << tensor_Reduction.host_view() << "\n" + << "\nReference reduction =\n" << reference_Reduction.host_view() << "\n"; + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + ElementAccumulator, typename Gemm::LayoutC, + ElementAccumulator, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C_ref.host_ref(), + reference_d_Y.host_ref(), + ElementAccumulator(0) + ); + + using ElementC = typename Gemm::ElementC; + + ReferenceOp reference_op; + + // compute backwards + for (int m = 0; m < problem_size.m(); ++m) { + ElementAccumulator reduced_value = ElementAccumulator(); + for (int n = 0; n < problem_size.n(); ++n) { + ElementAccumulator d_full = reference_op(reference_d_Y.at({m, 
n}), tensor_Tensor.at({m, n})); + reduced_value += d_full; + reference_D.at({m, n}) = ElementC(d_full); + } + reference_Reduction.at({m, 0}) = reduced_value; + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementAccumulator alpha = ElementAccumulator(1), + ElementAccumulator beta = ElementAccumulator(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + tensor_Reduction.device_data(), + tensor_Tensor.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0), + tensor_Reduction.layout().stride(0), + tensor_Tensor.layout().stride(0), + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + // + // Profile + // + + #if 0 // profiling disabled for now. 
+ + int const kWorkspaces = 100; + + cutlass::DeviceAllocation profiling_tensor_A(tensor_A.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_B(tensor_B.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_C(tensor_C.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_D(tensor_D.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_Reduction(tensor_Reduction.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_Tensor(tensor_Tensor.capacity() * kWorkspaces); + + cudaEvent_t events[2]; + for (auto & event : events) { + cudaError_t result = cudaEventCreate(&event); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << " cudaEventCreate() failed with error " << cudaGetErrorString(result); + return false; + break; + } + } + + int const kWarmupIterations = 5; + int const kProfilingIterations = 100; + + for (int i = 0; i < kWarmupIterations; ++i) { + status = gemm_op(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + } + + + cudaError_t result = cudaEventRecord(events[0]); + EXPECT_EQ(result, cudaSuccess); + + for (int i = 0; i < kProfilingIterations; ++i) { + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + profiling_tensor_A.get() + tensor_A.capacity() * (i % kWorkspaces), + profiling_tensor_B.get() + tensor_B.capacity() * (i % kWorkspaces), + profiling_tensor_C.get() + tensor_C.capacity() * (i % kWorkspaces), + profiling_tensor_D.get() + tensor_D.capacity() * (i % kWorkspaces), + profiling_tensor_Reduction.get() + tensor_Reduction.capacity() * (i % kWorkspaces), + profiling_tensor_Tensor.get() + tensor_Tensor.capacity() * (i % kWorkspaces), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + 
tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0), + tensor_Reduction.layout().stride(0), + tensor_Tensor.layout().stride(0), + }; + + gemm_op.initialize(arguments, workspace.get()); + status = gemm_op(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + } + + result = cudaEventRecord(events[1]); + EXPECT_EQ(result, cudaSuccess); + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess); + + float elapsed_time = 0; + result = cudaEventElapsedTime(&elapsed_time, events[0], events[1]); + EXPECT_EQ(result, cudaSuccess); + + double average_time = double(elapsed_time) / double(kProfilingIterations); + + std::cout << problem_size << ": " << average_time << " ms" << std::endl; + + for (auto & event : events) { + cudaEventDestroy(event); + } + #endif + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestGemmWithReduction( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count = 1, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedGemmWithReduction testbed; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped.h new file mode 100644 index 
0000000000000000000000000000000000000000..c7317eb855477e63fe19858ca51cd5722f236eb5 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped.h @@ -0,0 +1,501 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/device/gemm_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedGrouped { + + // + // Type definitions + // + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + + using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + + using MatrixCoord = typename LayoutC::TensorCoord; + + // + // Data members + // + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint32_t seed; + + int problem_count; + + std::vector problem_sizes_host; + cutlass::DeviceAllocation problem_sizes_device; + + std::vector offset_A; 
+ std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector lda_host; + std::vector ldb_host; + std::vector ldc_host; + std::vector ldd_host; + + cutlass::DeviceAllocation lda; + cutlass::DeviceAllocation ldb; + cutlass::DeviceAllocation ldc; + cutlass::DeviceAllocation ldd; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + + // + // Methods + // + + TestbedGrouped( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 
0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // no fill - remain zero + } + + return true; + } + + /// Initializes data structures + void initialize() { + + // + // Choose random problem sizes + // + + // construct a few problems of random sizes + srand(seed); + + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + + lda_host.resize(problem_count); + ldb_host.resize(problem_count); + ldc_host.resize(problem_count); + ldd_host.resize(problem_count); + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + + cutlass::gemm::GemmCoord problem( + 8 * (rand() % 64) + 24, + 8 * (rand() % 64) + 24, + 8 * (rand() % 64) + 24); + + if (!i) { + problem = cutlass::gemm::GemmCoord(48, 16, 8); + } + + problem_sizes_host.at(i) = problem; + + // std::cout << "Problem[" << i << "]: " << problem << std::endl; + + lda_host.at(i) = LayoutA::packed({problem.m(), problem.k()}).stride(0); + ldb_host.at(i) = LayoutB::packed({problem.k(), problem.n()}).stride(0); + ldc_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); + ldd_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = problem.m() * problem.k(); + int64_t elements_B = problem.k() * problem.n(); + int64_t elements_C = problem.m() * problem.n(); + int64_t elements_D = problem.m() * problem.n(); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + // Random strides between problems? 
+ } + + problem_sizes_device.reset(problem_count); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + lda.reset(problem_count); + ldb.reset(problem_count); + ldc.reset(problem_count); + ldd.reset(problem_count); + + lda.copy_from_host(lda_host.data()); + ldb.copy_from_host(ldb_host.data()); + ldc.copy_from_host(ldc_host.data()); + ldd.copy_from_host(ldd_host.data()); + + // + // Assign pointers + // + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + + std::vector ptr_A_host(problem_count); + std::vector ptr_B_host(problem_count); + std::vector ptr_C_host(problem_count); + std::vector ptr_D_host(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + + ptr_A.reset(problem_count); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(problem_count); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(problem_count); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(problem_count); + ptr_D.copy_from_host(ptr_D_host.data()); + + // + // Initialize the problems of the workspace + // + + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.m(), problem.k()}; + MatrixCoord extent_B{problem.k(), problem.n()}; + MatrixCoord extent_C{problem.m(), problem.n()}; + + std::vector matrix_A(layout_A.capacity(extent_A)); + std::vector matrix_B(layout_B.capacity(extent_B)); + std::vector matrix_C(layout_C.capacity(extent_C)); + std::vector matrix_D(layout_D.capacity(extent_C)); + + 
initialize_tensor(cutlass::TensorView(matrix_A.data(), layout_A, extent_A), init_A, seed * 2021); + initialize_tensor(cutlass::TensorView(matrix_B.data(), layout_B, extent_B), init_B, seed * 2022); + initialize_tensor(cutlass::TensorView(matrix_C.data(), layout_C, extent_C), init_C, seed * 2023); + + cutlass::device_memory::copy_to_device(ptr_A_host.at(i), matrix_A.data(), matrix_A.size()); + cutlass::device_memory::copy_to_device(ptr_B_host.at(i), matrix_B.data(), matrix_B.size()); + cutlass::device_memory::copy_to_device(ptr_C_host.at(i), matrix_C.data(), matrix_C.size()); + cutlass::device_memory::copy_to_device(ptr_D_host.at(i), matrix_D.data(), matrix_D.size()); + } + } + + /// Verifies the result is a GEMM + bool verify( + ElementCompute alpha, + ElementCompute beta) { + + bool passed = true; + + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.m(), problem.k()}; + MatrixCoord extent_B{problem.k(), problem.n()}; + MatrixCoord extent_C{problem.m(), problem.n()}; + + std::vector matrix_A(layout_A.capacity(extent_A)); + std::vector matrix_B(layout_B.capacity(extent_B)); + std::vector matrix_C(layout_C.capacity(extent_C)); + std::vector matrix_D(layout_D.capacity(extent_C)); + std::vector matrix_Ref(layout_D.capacity(extent_C)); + + cutlass::device_memory::copy_to_host(matrix_A.data(), block_A.get() + offset_A.at(i), matrix_A.size()); + cutlass::device_memory::copy_to_host(matrix_B.data(), block_B.get() + offset_B.at(i), matrix_B.size()); + cutlass::device_memory::copy_to_host(matrix_C.data(), block_C.get() + offset_C.at(i), matrix_C.size()); + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); + + cutlass::TensorView view_A(matrix_A.data(), layout_A, extent_A); + 
cutlass::TensorView view_B(matrix_B.data(), layout_B, extent_B); + cutlass::TensorView view_C(matrix_C.data(), layout_C, extent_C); + cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); + cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); + + // Reference GEMM + cutlass::reference::host::GemmComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementAccumulator + >( + problem, + alpha, + view_A, + Gemm::kTransformA, + view_B, + Gemm::kTransformB, + beta, + view_C, + view_Ref, + ElementAccumulator(0) + ); + + // Ensure that no input or output is entirely zero + EXPECT_GT(cutlass::reference::host::TensorNorm(view_A), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_B), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_C), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_D), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_Ref), 0); + + // Compare against reference + passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); + + if (!passed) { + std::ofstream file("testbed_grouped_errors.txt"); + + file + << "problem: " << problem << " [group: " << i << "]\n" + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << view_A + << "\nB =\n" << view_B + << "\nC =\n" << view_C + << "\n\nReference =\n" << view_Ref + << "\nComputed =\n" << view_D; + + return passed; + } + } + + return passed; + } + + /// Executes one test + bool run( + int problem_count, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + this->problem_count = problem_count; + + // Initialize the problem + initialize(); + + int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), problem_count); + + // Early exit + if (!threadblock_count) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device resources." 
<< std::endl; + } + return true; + } + + // Configure the GEMM arguments + typename EpilogueOutputOp::Params epilogue_op(alpha, beta); + + // Configure GEMM arguments + typename Gemm::Arguments args( + problem_sizes_device.get(), + problem_count, + threadblock_count, + epilogue_op, + ptr_A.get(), + ptr_B.get(), + ptr_C.get(), + ptr_D.get(), + lda.get(), + ldb.get(), + ldc.get(), + ldd.get(), + problem_sizes_host.data() + ); + + // Initialize the GEMM object + Gemm gemm; + + size_t workspace_size = gemm.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); + + cutlass::Status status = gemm.initialize(args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + return false; + } + + // Run the GEMM object + status = gemm.run(); + + if (status != cutlass::Status::kSuccess) { + return false; + } + + // Wait for completion + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) + << "Kernel execution error: " << cudaGetErrorString(result); + + if (result != cudaSuccess) { + return false; + } + + // Verify correctness + return verify(alpha, beta); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h new file mode 100644 index 0000000000000000000000000000000000000000..f8f08f23c4477745648f1cf8f9e439ae6b5061e2 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h @@ -0,0 +1,502 @@ +/*************************************************************************************************** + * Copyright (c) 
2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for grouped Rank2K interface + +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped.h" +#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/rank_2k_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedGrouped { + + // + // Type definitions + // + + using ElementA = typename Rank2K::ElementA; + using ElementB = typename Rank2K::ElementB; + using ElementC = typename Rank2K::ElementC; + using ElementAccumulator = typename Rank2K::ElementAccumulator; + + using EpilogueOutputOp = typename Rank2K::EpilogueOutputOp; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using LayoutA = typename Rank2K::LayoutA; + using LayoutB = typename Rank2K::LayoutB; + using LayoutC = typename Rank2K::LayoutC; + + using MatrixCoord = typename LayoutC::TensorCoord; + + // + // Data members + // + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint32_t seed; + + int problem_count; + + std::vector problem_sizes_host; + cutlass::DeviceAllocation problem_sizes_device; + + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector 
offset_D; + + std::vector lda_host; + std::vector ldb_host; + std::vector ldc_host; + std::vector ldd_host; + + cutlass::DeviceAllocation lda; + cutlass::DeviceAllocation ldb; + cutlass::DeviceAllocation ldc; + cutlass::DeviceAllocation ldd; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + + // + // Methods + // + + TestbedGrouped( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == 
cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // no fill - remain zero + } + + return true; + } + + /// Initializes data structures + void initialize() { + + // + // Choose random problem sizes + // + + // construct a few problems of random sizes + srand(seed); + + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + + lda_host.resize(problem_count); + ldb_host.resize(problem_count); + ldc_host.resize(problem_count); + ldd_host.resize(problem_count); + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + + auto N = 8 * (rand() % 64) + 24; + auto K = 8 * (rand() % 64) + 24; + cutlass::gemm::GemmCoord problem(N, N, K); + + if (!i) { + problem = cutlass::gemm::GemmCoord(16, 16, 8); + } + + problem_sizes_host.at(i) = problem; + + lda_host.at(i) = LayoutA::packed({problem.n(), problem.k()}).stride(0); + ldb_host.at(i) = LayoutB::packed({problem.n(), problem.k()}).stride(0); + ldc_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); + ldd_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = problem.n() * problem.k(); + int64_t elements_B = problem.n() * problem.k(); + int64_t elements_C = problem.n() * problem.n(); + int64_t elements_D = problem.n() * problem.n(); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + // Random strides between problems? 
+ } + + problem_sizes_device.reset(problem_count); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + lda.reset(problem_count); + ldb.reset(problem_count); + ldc.reset(problem_count); + ldd.reset(problem_count); + + lda.copy_from_host(lda_host.data()); + ldb.copy_from_host(ldb_host.data()); + ldc.copy_from_host(ldc_host.data()); + ldd.copy_from_host(ldd_host.data()); + + // + // Assign pointers + // + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + + std::vector ptr_A_host(problem_count); + std::vector ptr_B_host(problem_count); + std::vector ptr_C_host(problem_count); + std::vector ptr_D_host(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + + ptr_A.reset(problem_count); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(problem_count); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(problem_count); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(problem_count); + ptr_D.copy_from_host(ptr_D_host.data()); + + // + // Initialize the problems of the workspace + // + + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.n(), problem.k()}; + MatrixCoord extent_B{problem.n(), problem.k()}; + MatrixCoord extent_C{problem.n(), problem.n()}; + + std::vector matrix_A(layout_A.capacity(extent_A)); + std::vector matrix_B(layout_B.capacity(extent_B)); + std::vector matrix_C(layout_C.capacity(extent_C)); + std::vector matrix_D(layout_D.capacity(extent_C)); + + 
initialize_tensor(cutlass::TensorView(matrix_A.data(), layout_A, extent_A), init_A, seed * 2021); + initialize_tensor(cutlass::TensorView(matrix_B.data(), layout_B, extent_B), init_B, seed * 2022); + initialize_tensor(cutlass::TensorView(matrix_C.data(), layout_C, extent_C), init_C, seed * 2023); + + cutlass::device_memory::copy_to_device(ptr_A_host.at(i), matrix_A.data(), matrix_A.size()); + cutlass::device_memory::copy_to_device(ptr_B_host.at(i), matrix_B.data(), matrix_B.size()); + cutlass::device_memory::copy_to_device(ptr_C_host.at(i), matrix_C.data(), matrix_C.size()); + cutlass::device_memory::copy_to_device(ptr_D_host.at(i), matrix_D.data(), matrix_D.size()); + } + } + + /// Verifies the result is a Rank2K + bool verify( + ElementCompute alpha, + ElementCompute beta) { + + bool passed = true; + + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.n(), problem.k()}; + MatrixCoord extent_B{problem.n(), problem.k()}; + MatrixCoord extent_C{problem.n(), problem.n()}; + + std::vector matrix_A(layout_A.capacity(extent_A)); + std::vector matrix_B(layout_B.capacity(extent_B)); + std::vector matrix_C(layout_C.capacity(extent_C)); + std::vector matrix_D(layout_D.capacity(extent_C)); + std::vector matrix_Ref(layout_D.capacity(extent_C)); + + cutlass::device_memory::copy_to_host(matrix_A.data(), block_A.get() + offset_A.at(i), matrix_A.size()); + cutlass::device_memory::copy_to_host(matrix_B.data(), block_B.get() + offset_B.at(i), matrix_B.size()); + cutlass::device_memory::copy_to_host(matrix_C.data(), block_C.get() + offset_C.at(i), matrix_C.size()); + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); + + cutlass::TensorView view_A(matrix_A.data(), layout_A, extent_A); + 
cutlass::TensorView view_B(matrix_B.data(), layout_B, extent_B); + cutlass::TensorView view_C(matrix_C.data(), layout_C, extent_C); + cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); + cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); + + // Reference Rank2K + cutlass::reference::host::Rank2KComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementAccumulator + >( + problem, + alpha, + view_A, + Rank2K::kTransformA, + view_B, + Rank2K::kTransformB, + beta, + view_C, + view_Ref, + ElementAccumulator(0), + Rank2K::kFillModeC, + Rank2K::kBlasMode + ); + + // Ensure that no input or output is entirely zero + EXPECT_GT(cutlass::reference::host::TensorNorm(view_A), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_B), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_C), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_D), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_Ref), 0); + + // Compare against reference + passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); + + if (!passed) { + std::ofstream file("testbed_grouped_errors.txt"); + + file + << "problem: " << problem << " [group: " << i << "]\n" + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << view_A + << "\nB =\n" << view_B + << "\nC =\n" << view_C + << "\n\nReference =\n" << view_Ref + << "\nComputed =\n" << view_D; + + return passed; + } + } + + return passed; + } + + /// Executes one test + bool run( + int problem_count, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + this->problem_count = problem_count; + + // Initialize the problem + initialize(); + + int threadblock_count = Rank2K::sufficient(problem_sizes_host.data(), problem_count); + + // Early exit + if (!threadblock_count) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device resources." 
<< std::endl; + } + return true; + } + + // Configure the Rank2K arguments + typename EpilogueOutputOp::Params epilogue_op(alpha, beta); + + // Configure Rank2K arguments + typename Rank2K::Arguments args( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_sizes_device.get(), + problem_count, + threadblock_count, + epilogue_op, + ptr_A.get(), + ptr_B.get(), + ptr_C.get(), + ptr_D.get(), + lda.get(), + ldb.get(), + ldc.get(), + ldd.get(), + problem_sizes_host.data() + ); + + // Initialize the Rank2K object + Rank2K rank2k; + + size_t workspace_size = rank2k.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); + + cutlass::Status status = rank2k.initialize(args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + return false; + } + + // Run the Rank2K object + status = rank2k.run(); + + if (status != cutlass::Status::kSuccess) { + return false; + } + + // Wait for completion + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) + << "Kernel execution error: " << cudaGetErrorString(result); + + if (result != cudaSuccess) { + return false; + } + + // Verify correctness + return verify(alpha, beta); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..e9315e12e8711f50256e4cfe05666201acd614d3 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h @@ -0,0 +1,461 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for grouped Rank2K problem visitors +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/device_kernel.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Use simple problem visitor as a baseline +template +struct BaselineProblemVisitor : public cutlass::gemm::kernel::BaseGroupedProblemVisitor { + using Base = cutlass::gemm::kernel::BaseGroupedProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + static cutlass::FillMode const kFillModeC = FillModeC; + + struct SharedStorage {}; + + int32_t tile_count_sum; + SharedStorage &shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + BaselineProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base(params_, block_idx), + shared_storage(shared_storage_) + { + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + tile_count_sum = this->tile_count(grid); + } + + CUTLASS_DEVICE + bool next_tile() { + if (this->tile_idx < tile_count_sum) { + return true; + } + + do { + ++this->problem_idx; + + if (this->problem_idx >= this->params.problem_count) { + return false; + } + + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + + this->problem_tile_start = tile_count_sum; + tile_count_sum += this->tile_count(grid); + + } while (tile_count_sum <= this->tile_idx); + + return true; + } + + static size_t get_workspace_size(const cutlass::gemm::GemmCoord* 
host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count) { + return 0; + } + + static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count, + void* host_workspace_ptr) {} + + CUTLASS_DEVICE + cutlass::gemm::GemmCoord threadblock_offset(int32_t threadblock_id) const { + int32_t macro_id = threadblock_id / ProblemSizeHelper::OffsetHelper::kThreadblockSkewRatio; + int32_t macro_row = ceil(cutlass::fast_sqrt((2*macro_id) + 2.25) - 0.5) - 1; + int32_t macro_col = macro_id - (((macro_row+1) * macro_row)/2); + + if (FillModeC == cutlass::FillMode::kUpper) { + cutlass::swap(macro_row, macro_col); + } + + int32_t row = ProblemSizeHelper::OffsetHelper::macro_row_to_row(macro_row, threadblock_id); + int32_t col = ProblemSizeHelper::OffsetHelper::macro_col_to_col(macro_col, threadblock_id); + + return cutlass::gemm::GemmCoord(row, col, 0); + } +}; + +template +struct ProblemVisitorKernel { + struct SharedStorage { + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + struct Params { + typename ProblemVisitor::Params problem_visitor_params; + int32_t* visited_problems_ptr; + int32_t* visited_tiles_ptr; + int32_t visits_per_block; + + Params(): + visited_problems_ptr(nullptr), + visited_tiles_ptr(nullptr), + visits_per_block(0) {} + + Params(typename ProblemVisitor::Params problem_visitor_params_, + int32_t* visited_problems_ptr_, + int32_t* visited_tiles_ptr_, + int32_t visits_per_block_): + problem_visitor_params(problem_visitor_params_), + visited_problems_ptr(visited_problems_ptr_), + visited_tiles_ptr(visited_tiles_ptr_), + visits_per_block(visits_per_block_) {} + }; + + CUTLASS_DEVICE + void operator()(const Params& params, SharedStorage &shared_storage) { + int32_t store_offset = params.visits_per_block * blockIdx.x; + ProblemVisitor problem_visitor(params.problem_visitor_params, + shared_storage.problem_visitor, + blockIdx.x); + + while (problem_visitor.next_tile()) { + 
cutlass::gemm::GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + cutlass::gemm::GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + cutlass::gemm::GemmCoord tile_offset = problem_visitor.threadblock_offset(threadblock_idx); + + problem_visitor.advance(gridDim.x); + + // + // Early exit conditions + // 1) Out of range + // 2) Upper-triangular block in lower-triangular problem + // 3) Lower-triangular block in upper-triangular problem + // + + if (grid_shape.m() <= tile_offset.m() || + grid_shape.n() <= tile_offset.n()) { + continue; + } + + if (ProblemVisitor::kFillModeC == cutlass::FillMode::kLower && + (tile_offset.m() + 1) * ProblemVisitor::ThreadblockShape::kM <= tile_offset.n() * ProblemVisitor::ThreadblockShape::kN) { + continue; + } + + if (ProblemVisitor::kFillModeC == cutlass::FillMode::kUpper && + tile_offset.m() * ProblemVisitor::ThreadblockShape::kM >= (tile_offset.n() + 1) * ProblemVisitor::ThreadblockShape::kN) { + continue; + } + + if (threadIdx.x == 0) { + params.visited_problems_ptr[store_offset] = problem_idx; + params.visited_tiles_ptr[store_offset] = threadblock_idx; + ++store_offset; + } + } + } +}; + +template +struct ProblemVisitorRunner { + using BaseKernel = ProblemVisitorKernel; + using Params = typename BaseKernel::Params; + + Params params; + std::vector host_problem_sizes; + int32_t problem_count; + int32_t threadblock_count; + int32_t visits_per_block; + cutlass::DeviceAllocation visited_problems; + cutlass::DeviceAllocation visited_tiles; + cutlass::DeviceAllocation device_problem_sizes; + cutlass::DeviceAllocation workspace; + std::vector host_visited_problems; + std::vector host_visited_tiles; + + ProblemVisitorRunner(const std::vector& host_problem_sizes_, + int32_t threadblock_count_): + host_problem_sizes(host_problem_sizes_), + 
problem_count(int32_t(host_problem_sizes_.size())), + threadblock_count(threadblock_count_) {} + + /// Initializes GEMM state from arguments. + cutlass::Status initialize() { + size_t workspace_bytes = ProblemVisitor::get_workspace_size( + host_problem_sizes.data(), + problem_count, + threadblock_count); + + workspace.reset(workspace_bytes); + std::vector host_workspace(workspace_bytes); + + int32_t tile_count = ProblemVisitor::group_tile_count(host_problem_sizes.data(), problem_count); + + ProblemVisitor::host_precompute(host_problem_sizes.data(), problem_count, + threadblock_count, host_workspace.data()); + + workspace.copy_from_host(host_workspace.data(), workspace_bytes); + + device_problem_sizes.reset(problem_count); + device_problem_sizes.copy_from_host(host_problem_sizes.data(), problem_count); + + visits_per_block = (tile_count - 1 + threadblock_count) / threadblock_count; + int32_t total_visits = visits_per_block * threadblock_count; + + visited_problems.reset(total_visits); + visited_tiles.reset(total_visits); + host_visited_problems.resize(total_visits); + host_visited_tiles.resize(total_visits); + + cudaError_t result = cudaMemset(visited_problems.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + result = cudaMemset(visited_tiles.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + typename ProblemVisitor::Params pv_params(device_problem_sizes.get(), problem_count, workspace.get(), tile_count); + params = Params(pv_params, visited_problems.get(), visited_tiles.get(), visits_per_block); + + return cutlass::Status::kSuccess; + } + + bool verify() { + // Sort by problem size and then by threadblock_idx + std::vector indices(host_visited_problems.size()); + std::iota(indices.begin(), indices.end(), 0); + + std::stable_sort(indices.begin(), indices.end(), + [&](int32_t i1, int32_t i2) { + if 
(host_visited_problems[i1] == host_visited_problems[i2]) { + return host_visited_tiles[i1] < host_visited_tiles[i2]; + } + return host_visited_problems[i1] < host_visited_problems[i2]; + }); + + int32_t idx = 0; + + // Skip any entries that were not visited + while (host_visited_problems[indices[idx]] == -1) { + ++idx; + } + + // Check that each problem visited has the tiles we expect + for (int32_t problem_idx = 0; problem_idx < problem_count; ++problem_idx) { + auto problem = host_problem_sizes[problem_idx]; + ProblemVisitor::possibly_transpose_problem(problem); + int32_t problem_tiles = ProblemVisitor::tile_count(ProblemVisitor::grid_shape(problem)); + for (int i = 0; i < problem_tiles; ++i) { + EXPECT_EQ(problem_idx, host_visited_problems[indices[idx]]); + EXPECT_EQ(i, host_visited_tiles[indices[idx]]); + ++idx; + } + } + + return true; + } + + bool run(bool skip_tile_check=false, cudaStream_t stream = nullptr) { + cutlass::Status status = initialize(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Initialization failed" << std::endl; + return false; + } + + dim3 grid(threadblock_count, 1, 1); + dim3 block(ProblemVisitor::kThreadCount, 1, 1); + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + cutlass::Kernel<<>>(params); + + cudaError_t result = cudaGetLastError(); + if (result != cudaSuccess) { + std::cerr << "grid launch failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "cudaDeviceSynchronize failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + visited_problems.copy_to_host(host_visited_problems.data()); + visited_tiles.copy_to_host(host_visited_tiles.data()); + + if (skip_tile_check) { + return true; + } + + return verify(); + } +}; + +template +struct TestbedGroupedRank2KScheduler { + + using BaselinePV = BaselineProblemVisitor, + ThreadblockShape, + PrefetchTileCount, + 
ThreadCount, + FillModeC>; + + // + // Data members + // + + // Whether to skip checking that the tiles are visited as expected. This is useful + // in cases where ThreadblockShape::kM != ThreadblockShape::kN, for which the grouped + // Rank2K scheduler may assign out-of-bounds tiles that will cause a threadblock to + // exit early, but which are difficult to detect in tests without reimplementing + // this functionality. + bool skip_tile_check; + uint32_t seed; + int problem_count; + int threadblock_count; + std::vector problem_sizes_host; + + // + // Methods + // + + TestbedGroupedRank2KScheduler(bool skip_tile_check_=false, uint32_t seed_ = 3080): + skip_tile_check(skip_tile_check_), seed(seed_) { srand(seed); } + + /// Initializes data structures + void initialize(int32_t scale_factor) { + + // + // Choose random problem sizes + // + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + int n = scale_factor * (rand() % 64) + 24; + + cutlass::gemm::GemmCoord problem( + n, + n, + scale_factor * (rand() % 64) + 24); + + problem_sizes_host.at(i) = problem; + } + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + using PV = cutlass::gemm::kernel::Rank2KGroupedProblemVisitor< + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount, + FillModeC>; + ProblemVisitorRunner runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(runner.run(skip_tile_check)); + + // Check that this problem visitor visits the same problems and tiles as the baseline + EXPECT_EQ(baseline_runner.host_visited_problems, runner.host_visited_problems); + EXPECT_EQ(baseline_runner.host_visited_tiles, runner.host_visited_tiles); + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + // Compare the next visitor with the baseline visitor + compare_visitors(baseline_runner); + + // Recurse to compare the next visitors + 
compare_visitors(baseline_runner); + } + + /// Executes the test on all scheduler modes + void run(int problem_count, int threadblock_count, int scale_factor=8) { + + this->problem_count = problem_count; + this->threadblock_count = threadblock_count; + + // Initialize the problem + initialize(scale_factor); + + // Run the baseline visitor to which we will compare all other visitors + ProblemVisitorRunner baseline_runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(baseline_runner.run(skip_tile_check)); + + compare_visitors(baseline_runner); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..bda2704b517ea95052e2c2060b50712b686344f6 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h @@ -0,0 +1,407 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for grouped GEMM problem visitors +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" +#include "cutlass/util/device_memory.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Use simple problem visitor as a baseline +template +struct BaselineProblemVisitor : public cutlass::gemm::kernel::BaseGroupedProblemVisitor { + using Base = cutlass::gemm::kernel::BaseGroupedProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + + struct SharedStorage {}; + + int32_t tile_count_sum; + SharedStorage &shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + BaselineProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base(params_, block_idx), + shared_storage(shared_storage_) + { + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + tile_count_sum = this->tile_count(grid); + } + + CUTLASS_DEVICE + bool next_tile() { + if (this->tile_idx < tile_count_sum) { + return true; + } + + do { + ++this->problem_idx; + + if (this->problem_idx >= this->params.problem_count) { + return false; + } + + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + + this->problem_tile_start = tile_count_sum; + tile_count_sum += this->tile_count(grid); + + } while (tile_count_sum <= this->tile_idx); + + return true; + } + + static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t 
problem_count, + int32_t block_count) { + return 0; + } + + static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count, + void* host_workspace_ptr) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ProblemVisitorKernel { + struct SharedStorage { + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + struct Params { + typename ProblemVisitor::Params problem_visitor_params; + int32_t* visited_problems_ptr; + int32_t* visited_tiles_ptr; + int32_t visits_per_block; + + Params(): + visited_problems_ptr(nullptr), + visited_tiles_ptr(nullptr), + visits_per_block(0) {} + + Params(typename ProblemVisitor::Params problem_visitor_params_, + int32_t* visited_problems_ptr_, + int32_t* visited_tiles_ptr_, + int32_t visits_per_block_): + problem_visitor_params(problem_visitor_params_), + visited_problems_ptr(visited_problems_ptr_), + visited_tiles_ptr(visited_tiles_ptr_), + visits_per_block(visits_per_block_) {} + }; + + CUTLASS_DEVICE + void operator()(const Params& params, SharedStorage &shared_storage) { + int32_t store_offset = params.visits_per_block * blockIdx.x; + ProblemVisitor problem_visitor(params.problem_visitor_params, + shared_storage.problem_visitor, + blockIdx.x); + + while (problem_visitor.next_tile()) { + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + if (threadIdx.x == 0) { + params.visited_problems_ptr[store_offset] = problem_idx; + params.visited_tiles_ptr[store_offset] = threadblock_idx; + ++store_offset; + } + problem_visitor.advance(gridDim.x); + } + } +}; + +template +struct ProblemVisitorRunner { + using BaseKernel = ProblemVisitorKernel; + using Params = typename BaseKernel::Params; + + Params params; + std::vector host_problem_sizes; + int32_t problem_count; + int32_t threadblock_count; + int32_t 
visits_per_block; + cutlass::DeviceAllocation visited_problems; + cutlass::DeviceAllocation visited_tiles; + cutlass::DeviceAllocation device_problem_sizes; + cutlass::DeviceAllocation workspace; + std::vector host_visited_problems; + std::vector host_visited_tiles; + + ProblemVisitorRunner(const std::vector& host_problem_sizes_, + int32_t threadblock_count_): + host_problem_sizes(host_problem_sizes_), + problem_count(int32_t(host_problem_sizes_.size())), + threadblock_count(threadblock_count_) {} + + /// Initializes GEMM state from arguments. + cutlass::Status initialize() { + size_t workspace_bytes = ProblemVisitor::get_workspace_size( + host_problem_sizes.data(), + problem_count, + threadblock_count); + + workspace.reset(workspace_bytes); + std::vector host_workspace(workspace_bytes); + + int32_t tile_count = ProblemVisitor::group_tile_count(host_problem_sizes.data(), problem_count); + + ProblemVisitor::host_precompute(host_problem_sizes.data(), problem_count, + threadblock_count, host_workspace.data()); + + workspace.copy_from_host(host_workspace.data(), workspace_bytes); + + device_problem_sizes.reset(problem_count); + device_problem_sizes.copy_from_host(host_problem_sizes.data(), problem_count); + + visits_per_block = (tile_count - 1 + threadblock_count) / threadblock_count; + int32_t total_visits = visits_per_block * threadblock_count; + + visited_problems.reset(total_visits); + visited_tiles.reset(total_visits); + host_visited_problems.resize(total_visits); + host_visited_tiles.resize(total_visits); + + cudaError_t result = cudaMemset(visited_problems.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + result = cudaMemset(visited_tiles.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + typename ProblemVisitor::Params pv_params(device_problem_sizes.get(), problem_count, workspace.get(), tile_count); + params = 
Params(pv_params, visited_problems.get(), visited_tiles.get(), visits_per_block); + + return cutlass::Status::kSuccess; + } + + bool verify() { + // Sort by problem size and then by threadblock_idx + std::vector indices(host_visited_problems.size()); + std::iota(indices.begin(), indices.end(), 0); + + std::stable_sort(indices.begin(), indices.end(), + [&](int32_t i1, int32_t i2) { + if (host_visited_problems[i1] == host_visited_problems[i2]) { + return host_visited_tiles[i1] < host_visited_tiles[i2]; + } + return host_visited_problems[i1] < host_visited_problems[i2]; + }); + + int32_t idx = 0; + + // Skip any entries that were not visited + while (host_visited_problems[indices[idx]] == -1) { + ++idx; + } + + // Check that each problem visited has the tiles we expect + for (int32_t problem_idx = 0; problem_idx < problem_count; ++problem_idx) { + auto problem = host_problem_sizes[problem_idx]; + ProblemVisitor::possibly_transpose_problem(problem); + int32_t problem_tiles = ProblemVisitor::tile_count(ProblemVisitor::grid_shape(problem)); + for (int i = 0; i < problem_tiles; ++i) { + EXPECT_EQ(problem_idx, host_visited_problems[indices[idx]]); + EXPECT_EQ(i, host_visited_tiles[indices[idx]]); + ++idx; + } + } + + return true; + } + + bool run(cudaStream_t stream = nullptr) { + cutlass::Status status = initialize(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Initialization failed" << std::endl; + return false; + } + + dim3 grid(threadblock_count, 1, 1); + dim3 block(ProblemVisitor::kThreadCount, 1, 1); + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + cutlass::Kernel<<>>(params); + + cudaError_t result = cudaGetLastError(); + if (result != cudaSuccess) { + std::cerr << "grid launch failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "cudaDeviceSynchronize failed with error " << cudaGetErrorString(result) << std::endl; 
+ return false; + } + + visited_problems.copy_to_host(host_visited_problems.data()); + visited_tiles.copy_to_host(host_visited_tiles.data()); + + return verify(); + } +}; + +template +struct TestbedGroupedGemmScheduler { + + using PSHelper = cutlass::gemm::kernel::detail::GemmGroupedProblemSizeHelper; + using BaselinePV = BaselineProblemVisitor; + + // + // Data members + // + uint32_t seed; + int problem_count; + int threadblock_count; + std::vector problem_sizes_host; + + // + // Methods + // + + TestbedGroupedGemmScheduler(uint32_t seed_ = 3080): + seed(seed_) { srand(seed); } + + /// Initializes data structures + void initialize(int32_t scale_factor) { + + // + // Choose random problem sizes + // + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + + cutlass::gemm::GemmCoord problem( + scale_factor * (rand() % 64) + 24, + scale_factor * (rand() % 64) + 24, + scale_factor * (rand() % 64) + 24); + + problem_sizes_host.at(i) = problem; + } + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + using PV = cutlass::gemm::kernel::GemmGroupedProblemVisitor< + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount, + Transpose>; + ProblemVisitorRunner runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(runner.run()); + + // Check that this problem visitor visits the same problems and tiles as the baseline + EXPECT_EQ(baseline_runner.host_visited_problems, runner.host_visited_problems); + EXPECT_EQ(baseline_runner.host_visited_tiles, runner.host_visited_tiles); + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + // Compare the next visitor with the baseline visitor + compare_visitors(baseline_runner); + + // Recurse to compare the next visitors + compare_visitors(baseline_runner); + } + + /// Executes the test on all scheduler modes + void run(int problem_count, int threadblock_count, int 
scale_factor=8) { + + this->problem_count = problem_count; + this->threadblock_count = threadblock_count; + + // Initialize the problem + initialize(scale_factor); + + // Run the baseline visitor to which we will compare all other visitors + ProblemVisitorRunner baseline_runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(baseline_runner.run()); + + compare_visitors(baseline_runner); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_interleaved.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_interleaved.h new file mode 100644 index 0000000000000000000000000000000000000000..2a5956000db8e8c05ea22538e58149998b03e3fc --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_interleaved.h @@ -0,0 +1,346 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct InterleavedTestbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + // + // Methods + // + + InterleavedTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == 
cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Waives test if CUDA device is insufficient + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm::ElementA, + typename Gemm::LayoutA> tensor_A(problem_size.mk()); + + cutlass::HostTensor< + typename Gemm::ElementB, + typename Gemm::LayoutB> tensor_B(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm::ElementB, + typename Gemm::LayoutB> tensor_B_reordered(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> tensor_C(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> tensor_D(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> reference_D(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + cutlass::reorder_column( + tensor_B_reordered.host_ref(), tensor_B.host_ref(), problem_size); + + cutlass::reference::host::TensorCopy( + reference_D.host_view(), + tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B_reordered.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + tensor_A.device_ref(), + tensor_B_reordered.device_ref(), + tensor_C.device_ref(), + tensor_D.device_ref(), + {alpha, beta} + }; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(arguments); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Verify + // + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, + 
ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + beta, + reference_D.host_ref(), + ElementAccumulator(0) + ); + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D.host_view(), + tensor_D.host_view()); + + EXPECT_TRUE(passed); + if (!passed) { + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nB_reordered =\n" << tensor_B_reordered.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Runs a set of problem sizes + bool run_all() { + bool passed = true; + + int problem_size_m[] = { + InterleavedK, 256 + InterleavedK, 512 + InterleavedK + }; + + int problem_size_n[] = { + InterleavedK, 256 + InterleavedK, 512 + InterleavedK + }; + + int problem_size_k[] = { + InterleavedK, 256 + InterleavedK, 512 + InterleavedK + }; + + double problem_alpha[] = { + 1.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (double alpha : problem_alpha) { + for (double beta : problem_beta) { + + passed = run( + {m, n, 
k}, + ElementCompute(alpha), + ElementCompute(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + + return true; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_planar_complex.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..32452c30e05f64763a268195ae78138f26c09735 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_planar_complex.h @@ -0,0 +1,326 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm_planar_complex.h" +#include "cutlass/util/host_tensor_planar_complex.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +class TestbedPlanarComplex { +public: + + using ElementA = typename Gemm::ElementA; + using LayoutA = typename Gemm::LayoutA; + using ElementB = typename Gemm::ElementB; + using LayoutB = typename Gemm::LayoutB; + using ElementC = typename Gemm::ElementC; + using LayoutC = typename Gemm::LayoutC; + using ElementCompute = typename 
Gemm::EpilogueOutputOp::ElementCompute; + using ElementAccumulator = typename Gemm::ElementAccumulator; + + // + // Data members + // + + cutlass::gemm::GemmCoord problem_size; + cutlass::HostTensorPlanarComplex tensor_A; + cutlass::HostTensorPlanarComplex tensor_B; + cutlass::HostTensorPlanarComplex tensor_C; + cutlass::HostTensorPlanarComplex tensor_D; + cutlass::HostTensorPlanarComplex tensor_D_ref; + + // + // Methods + // + + TestbedPlanarComplex(cutlass::gemm::GemmCoord const & problem_size): problem_size(problem_size) { + + tensor_A.reset({problem_size.m(), problem_size.k()}); + tensor_B.reset({problem_size.k(), problem_size.n()}); + tensor_C.reset({problem_size.m(), problem_size.n()}); + tensor_D.reset({problem_size.m(), problem_size.n()}); + tensor_D_ref.reset({problem_size.m(), problem_size.n()}, false); + } + + void initialize() { + + uint64_t seed = 1073; + + int scope_max = 8; + int scope_min = -8; + + cutlass::reference::host::TensorFillRandomUniform( + tensor_A.host_view(), seed, scope_max, scope_min, 0); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_B.host_view(), seed * 2019, scope_max, scope_min, 0); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_C.host_view(), seed * 2020, scope_max, scope_min, 0); + + cutlass::reference::host::TensorFill(tensor_D.host_view(), cutlass::complex()); + cutlass::reference::host::TensorFill(tensor_D_ref.host_view(), cutlass::complex()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. 
+ bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + bool run( + cutlass::complex alpha = {1, 0}, + cutlass::complex beta = {0, 0}) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + initialize(); + + int batch_count = 1; + + ElementA *ptr_A = tensor_A.device_data(); + ElementB *ptr_B = tensor_B.device_data(); + ElementC *ptr_C = tensor_C.device_data(); + ElementC *ptr_D = tensor_D.device_data(); + + typename LayoutA::Stride::Index lda = tensor_A.layout().stride(0); + typename LayoutB::Stride::Index ldb = tensor_B.layout().stride(0); + typename LayoutC::Stride::Index ldc = tensor_C.layout().stride(0); + typename LayoutC::Stride::Index ldd = tensor_D.layout().stride(0); + + int64_t imag_stride_A = tensor_A.imaginary_stride(); + int64_t imag_stride_B = tensor_B.imaginary_stride(); + int64_t imag_stride_C = tensor_C.imaginary_stride(); + int64_t imag_stride_D = tensor_D.imaginary_stride(); + + // + // Launch device kernel + // + + Gemm gemm_op; + + typename Gemm::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + batch_count, + {alpha, beta}, + ptr_A, + ptr_A + imag_stride_A, + ptr_B, + ptr_B + imag_stride_B, + ptr_C, + ptr_C + imag_stride_C, + ptr_D, + ptr_D + imag_stride_D, + lda, + lda, + ldb, + ldb, + 
ldc, + ldc, + ldd, + ldd + }; + + cutlass::Status status = gemm_op(args); + + EXPECT_EQ(status, cutlass::Status::kSuccess); + + cudaError_t error = cudaDeviceSynchronize(); + + tensor_D.sync_host(); + + // + // Compute reference + // + + cutlass::reference::host::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C.host_ref(), + tensor_D_ref.host_ref() + ); + + bool passed = cutlass::reference::host::TensorEquals( + tensor_D.host_view(), + tensor_D_ref.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("gemm_planar_complex.txt"); + + output + << "A:\n" << tensor_A.host_view() << "\n" + << "B:\n" << tensor_B.host_view() << "\n" + << "C:\n" << tensor_C.host_view() << "\n" + << "Reference:\n" + << tensor_D_ref.host_view() << "\n" + << "Computed:\n" + << tensor_D.host_view() << "\n"; + } + + return passed; + } +}; + +template +bool TestOneGemmPlanarComplex(cutlass::gemm::GemmCoord problem_size) { + + TestbedPlanarComplex testbed(problem_size); + + return testbed.run(); +} + +template +bool TestAllGemmPlanarComplex() { + + int M[] = { + 16, 64, 72, 144, 264, 520, + }; + + int N[] = { + 16, 64, 72, 144, 248, 264, 520 + }; + + int K[] = { + 8, 64, 72, 96, 264, 520 + }; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + cutlass::complex alpha_values[] = { + {ElementCompute(1.25), ElementCompute(-0.5)} + }; + + cutlass::complex beta_values[] = { + {ElementCompute(-2.25), ElementCompute(1.5)} + }; + + for (int m : M) { + for (int n : N) { + for (int k : K) { + + test::gemm::device::TestbedPlanarComplex testbed({m, n, k}); + + for (auto const &alpha : alpha_values) { + for (auto const &beta : beta_values) { + + bool passed = testbed.run(alpha, beta); + if (!passed) { + return false; + } + } + } + } + } + } + + return true; +} + 
+//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..4d9f6743a45e5dc3a7b4ddd3e2a7b2abceffbb18 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h @@ -0,0 +1,641 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Rank 2k update interface + +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/rank_2k.h" +#include "cutlass/util/reference/host/rank_2k_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedRank2KUniversal { + + using ElementA = typename Rank2K::ElementA; + using ElementB = typename Rank2K::ElementB; + using ElementC = typename Rank2K::ElementC; + using ElementAccumulator = typename Rank2K::ElementAccumulator; + using ElementCompute = typename Rank2K::Rank2Kkernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + 
cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedRank2KUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + + EXPECT_TRUE(false) << "Input distribution not implemented"; + return false; + } + + return true; + } + + + /// Helper to initialize a tensor view + template + bool initialize_symmetric_tensor( + cutlass::TensorView view, + 
cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillSymmetricRandomUniform( + view, seed, Rank2K::kFillModeC, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillSymmetricRandomGaussian( + view, seed, Rank2K::kFillModeC, 0, 0.5, mantissa_in_bits); + } + else { + + EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; + return false; + } + + return true; + } + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the Rank2K workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.mk()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_symmetric_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_A.host_view().at({0, 0}) = typename Rank2K::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Rank2K::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Rank2K::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); + + bool passed = l2_norm < cutlass::MantissaInBits::error; + + return passed; + } + + /// Verifies the result is a Rank2K + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + cutlass::reference::host::Rank2KComplex< + typename Rank2K::ElementA, typename Rank2K::LayoutA, + typename Rank2K::ElementB, typename Rank2K::LayoutB, + typename Rank2K::ElementC, typename Rank2K::LayoutC, + ElementCompute, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Rank2K::kTransformA, + tensor_B.host_ref(), + Rank2K::kTransformB, + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0), + Rank2K::kFillModeC, + Rank2K::kBlasMode + ); + + return compare_reference(problem_size, alpha, 
beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Rank2K::Rank2Kkernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + +#if 0 + std::cout << "[TestbedRank2KUniversal::run()] problem(m, n, k): " << problem_size + << " alpha: " << ElementCompute(alpha) + << " beta: " << ElementCompute(beta) << std::endl; +#endif + + this->initialize(problem_size); + + // + // Initialize the Rank2K operator + // + + typename Rank2K::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + problem_size.n() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Rank2K rank2k_op; + + size_t workspace_size = Rank2K::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = rank2k_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the Rank2K + // + + status = rank2k_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + //if (true) { + if (!passed) { + std::stringstream fname; + + fname << "error_Rank2k_device_" + << "fill_mode_c_" + << (Rank2K::kFillModeC == cutlass::FillMode::kLower ? "lower_" : + (Rank2K::kFillModeC == cutlass::FillMode::kUpper ? 
"upper_" : "invalid_")) + << "mnk_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Rank2K::ThreadblockShape::kM << "x" + << Rank2K::ThreadblockShape::kN << "x" + << Rank2K::ThreadblockShape::kK << "_" + << Rank2K::WarpShape::kM << "x" + << Rank2K::WarpShape::kN << "x" + << Rank2K::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nD reference:\n" << reference_D.host_view() << "\n" + << "\nD computed:\n" << tensor_D.host_view() << "\n"; + + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestRank2kUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedRank2KUniversal testbed; + + using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +template +bool TestAllRank2KUniversal() { + bool passed = true; + + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + + int const kAlignment = cutlass::platform::is_same< + typename Rank2K::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 
4 : kAlignment; + + int const kAlignmentN = kAlignmentM; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value + ? 4 : kAlignment; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int problem_size_k[] = { + kAlignmentK, + Rank2K::ThreadblockShape::kK * Rank2K::kStages - kAlignmentK, + Rank2K::ThreadblockShape::kK * Rank2K::kStages * 3 - kAlignmentK + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + double problem_alpha[] = { + 1.0, 3.25 + }; + + double problem_beta[] = { + 0.0, 2.15 + }; + + using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + + // skip very small K problems + //if (k / batch_count < 2 * Rank2K::ThreadblockShape::kK) { + // continue; + //} + } + + cutlass::gemm::GemmCoord problem_size(n, n, k); + + TestbedRank2KUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +template +bool TestAllRank2KHermitianUniversal() { + bool passed = true; + + using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; + using ElementAccumulator = typename Rank2K::ElementAccumulator; + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + + int const kAlignment = 
cutlass::platform::is_same< + typename Rank2K::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = kAlignmentM; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value + ? 4 : kAlignment; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int problem_size_k[] = { + kAlignmentK, + Rank2K::ThreadblockShape::kK * Rank2K::kStages - kAlignmentK, + Rank2K::ThreadblockShape::kK * Rank2K::kStages * 3 - kAlignmentK + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + /* Complex alpha for HER2K */ + ElementAccumulator problem_alpha[] = { + {1.0}, + {1.25, 3.25}, + {-0.25, -2.25} + }; + + ElementAccumulator problem_beta[] = { + 0.0, -2.25 + }; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + + // skip very small K problems + //if (k / batch_count < 2 * Rank2K::ThreadblockShape::kK) { + // continue; + //} + } + + cutlass::gemm::GemmCoord problem_size(n, n, k); + + TestbedRank2KUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + alpha, + beta + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..cb46528a049ae1254d0492b6235821210e47b957 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h @@ -0,0 +1,511 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Rank 2k update interface + +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/rank_k_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedRank2KUniversal { + + using ElementA = typename RankK::ElementA; + using ElementC = typename RankK::ElementC; + using ElementAccumulator = typename RankK::ElementAccumulator; + using ElementCompute = typename RankK::RankKkernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + 
cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedRank2KUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + + EXPECT_TRUE(false) << "Input distribution not implemented"; + return false; + } + + return true; + } + + + /// Helper to initialize a tensor view + template + bool initialize_symmetric_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 
1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillSymmetricRandomUniform( + view, seed, RankK::kFillModeC, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillSymmetricRandomGaussian( + view, seed, RankK::kFillModeC, 0, 0.5, mantissa_in_bits); + } + else { + + EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; + return false; + } + + return true; + } + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the RankK workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_symmetric_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_A.host_view().at({0, 0}) = typename RankK::ElementA(1); + tensor_C.host_view().at({0, 0}) = typename RankK::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); + + bool passed = l2_norm < cutlass::MantissaInBits::error; + + return passed; + } + + /// Verifies the result is a RankK + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + cutlass::reference::host::Rank2KComplex< + typename RankK::ElementA, typename RankK::LayoutA, + typename RankK::ElementC, typename RankK::LayoutC, + ElementCompute, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + RankK::kTransformA, + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0), + RankK::kFillModeC, + RankK::kBlasMode + ); + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. 
+ bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename RankK::RankKkernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + +#if 0 + std::cout << "[TestbedRankKUniversal::run()] problem(m, n, k): " << problem_size + << " alpha: " << ElementCompute(alpha) + << " beta: " << ElementCompute(beta) << std::endl; +#endif + + this->initialize(problem_size); + + // + // Initialize the RankK operator + // + + typename RankK::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + RankK rank2k_op; + + size_t workspace_size = RankK::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = rank2k_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the RankK + // + + status = rank2k_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + //if (true) { + if (!passed) { + std::stringstream fname; + + fname << "error_RankK_device_" + << "fill_mode_c_" + << (RankK::kFillModeC == cutlass::FillMode::kLower ? "lower_" : + (RankK::kFillModeC == cutlass::FillMode::kUpper ? 
"upper_" : "invalid_")) + << "mnk_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << RankK::ThreadblockShape::kM << "x" + << RankK::ThreadblockShape::kN << "x" + << RankK::ThreadblockShape::kK << "_" + << RankK::WarpShape::kM << "x" + << RankK::WarpShape::kN << "x" + << RankK::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nD reference:\n" << reference_D.host_view() << "\n" + << "\nD computed:\n" << tensor_D.host_view() << "\n"; + + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestRank2kUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedRank2KUniversal testbed; + + using ElementCompute = typename RankK::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +template +bool TestAllRankKUniversal() { + bool passed = true; + + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + int const kAlignmentN = 128 / kMinimumOperandElementSize; + int const kAlignmentK = 128 / kMinimumOperandElementSize; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int problem_size_k[] = { + kAlignmentK, + RankK::ThreadblockShape::kK * RankK::kStages - kAlignmentK, + RankK::ThreadblockShape::kK * RankK::kStages * 3 - kAlignmentK + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // 
Just running one batch for now (removing 2, 3, 5, 7) + }; + + double problem_alpha[] = { + 1.0 + }; + + double problem_beta[] = { + 2.0 + }; + + + using ElementCompute = typename RankK::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + } + + cutlass::gemm::GemmCoord problem_size(n, n, k); + + TestbedRank2KUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sanity.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sanity.h new file mode 100644 index 0000000000000000000000000000000000000000..0a01a6a32ee2db84f2e890059423cd6b8477f766 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sanity.h @@ -0,0 +1,238 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/core_io.h" + +#include "testbed.h" + + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// List of Gemm internal paramters this testbed supports user verification +// +enum class ParameterID { + + // Threadblock-level parameters + kSmemASize, + kSmemBSize, + + // Warp-level parameters + kWarpFragmentASize, + kWarpFragmentBSize, + kWarpFragmentCSize, + kInvalid +}; + +struct Reference { + ParameterID parameter_id; + + union { + int value; + + struct { + int m, n, k; + } gemm_shape; + + struct { + int row, column; + } matrix_shape; + }; + + std::string error_msg; + + Reference( + ParameterID parameter_id_, + int value_=-1, + std::string const &error_msg_="") : parameter_id(parameter_id_), value(value_), error_msg(error_msg_) {} +}; + + +template +struct TestbedSanity { + + // + // Type definitions (All Gemm types top down) + // + + // Unpacking Gemm types in the following order + // Kernel-level > Threadblock-level > Warp-level > Instruction-level + + // kernel-level cutlass Gemm + using GemmKernel = typename Gemm::GemmKernel; + + // + // Threadblock-level gemm types + // + using MmaThreadBlock = typename GemmKernel::Mma; + + // Threadblock-level gemm shape covering one stage + using ThreadblockShape = typename MmaThreadBlock::Shape; + + // Shared memory size covering all stages + using SmemShapeA = typename 
MmaThreadBlock::Base::SharedStorage::ShapeA; + using SmemPaddingA = typename MmaThreadBlock::Policy::SmemPaddingA; + using SmemShapeB = typename MmaThreadBlock::Base::SharedStorage::ShapeB; + using SmemPaddingB = typename MmaThreadBlock::Policy::SmemPaddingB; + + + /// Number of stages + static int const kStages = MmaThreadBlock::Base::kStages; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = MmaThreadBlock::kWarpGemmIterations; + + + // + // Warp-level gemm types + // + + // Warp-level gemm operator + using MmaWarp = typename MmaThreadBlock::Operator; + + // Warp-level gemm shape covering all kgroups + using WarpShape = typename MmaWarp::Shape; + + // Warp-level framents holding operands A & B operand and destination C + using WarpFragmentA = typename MmaWarp::FragmentA; + using WarpFragmentB = typename MmaWarp::FragmentB; + using WarpFragmentC = typename MmaWarp::FragmentC; + + // + // Instruction-level gemm types + // + + // Instruction-level gemm operator + using MmaInstruction = typename MmaWarp::Policy::Operator; + + // Instruction shape + using InstructionShape = typename MmaInstruction::Shape; + + // Instruction-level framents holding operands A & B operand and destination C + using InstructionFragmentA = typename MmaInstruction::FragmentA; + using InstructionFragmentB = typename MmaInstruction::FragmentB; + using InstructionFragmentC = typename MmaInstruction::FragmentC; + + // + // Testbed types + // + + // Vector of values holding user provided reference + using ReferenceVector = std::vector; + + // + // Data members + // + ReferenceVector references; + + // + // Methods + // + + TestbedSanity(ReferenceVector const &references_ = ReferenceVector()) : references(references_){ } + + // verify all parameter in ReferenceVector + bool verify() { + for(auto ref : references) + verify_parameter(ref); + return true; + } + + // verify parameter of type Reference + void verify_parameter(Reference const& ref) { + 
switch(ref.parameter_id) { + case ParameterID::kWarpFragmentASize : EXPECT_TRUE(WarpFragmentA::kElements == ref.value) << *this; break; + case ParameterID::kWarpFragmentBSize : EXPECT_TRUE(WarpFragmentB::kElements == ref.value) << *this; break; + case ParameterID::kWarpFragmentCSize : EXPECT_TRUE(WarpFragmentC::kElements == ref.value) << *this; break; + } + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Overload output operators for TesbedSanity +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +std::ostream & operator<<(std::ostream &out, TestbedSanity const &test) { + + + out << "Gemm internal parameters" << std::endl + << " Threadblock-level parameters:" << std::endl + << " ThreadblockShape = " << typename TestbedSanity::ThreadblockShape() << std::endl + << " kStages = " << TestbedSanity::kStages << std::endl + << " kWarpGemmIterations = "<< TestbedSanity::kWarpGemmIterations << std::endl + <<" Shared memory sizes:" << std::endl + <<" SmemPaddingA = " << typename TestbedSanity::SmemPaddingA() << std::endl + <<" SmemPaddingB = " << typename TestbedSanity::SmemPaddingB() << std::endl + <<" SmemShapeA = " << typename TestbedSanity::SmemShapeA() << std::endl + <<" SmemShapeB = " << typename TestbedSanity::SmemShapeB() << std::endl + <<" Warp-level parameters" << std::endl + <<" WarpShape = " << typename TestbedSanity::WarpShape() << std::endl + <<" Fragment sizes:" << std::endl + <<" WarpFragmentA::kElements = " << TestbedSanity::WarpFragmentA::kElements << std::endl + <<" WarpFragmentB::kElements = " << TestbedSanity::WarpFragmentB::kElements << std::endl + <<" WarpFragmentC::kElements = " << TestbedSanity::WarpFragmentC::kElements << std::endl + <<" Instruction-level parameters" << std::endl + <<" InstructionShape = " << typename TestbedSanity::InstructionShape() << std::endl + <<" Fragment sizes:" << 
std::endl + <<" InstructionFragmentA::kElements = " << TestbedSanity::InstructionFragmentA::kElements << std::endl + <<" InstructionFragmentB::kElements = " << TestbedSanity::InstructionFragmentB::kElements << std::endl + <<" InstructionFragmentC::kElements = " << TestbedSanity::InstructionFragmentC::kElements << std::endl; + + return out; +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sparse.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sparse.h new file mode 100644 index 0000000000000000000000000000000000000000..a95bf996bac337b44da616dc9fbf9c9bdb2a625c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sparse.h @@ -0,0 +1,487 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + + Testbed for sparse operations not to be released for CUDA 11.0 GA. Expected release is 11.1. 
+*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SparseTestbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + static int const kSparse = Gemm::GemmKernel::kSparse; + static int const kMetaSizeInBits = Gemm::GemmKernel::kMetaSizeInBits; + static int const kMaxID2 = Gemm::GemmKernel::kMaxID2; + static int const kElementsPerElementE = Gemm::GemmKernel::kElementsPerElementE; + + using ElementE = typename Gemm::GemmKernel::ElementE; + using LayoutE = cutlass::layout::RowMajor; + using ReorderedLayoutE = typename Gemm::GemmKernel::LayoutE; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_E; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_uncompressed; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + cutlass::HostTensor tensor_E; + cutlass::HostTensor tensor_E_reordered; + + // 
+ // Methods + // + + SparseTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_E_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080) + : init_A(init_A_), + init_B(init_B_), + init_C(init_C_), + init_E(init_E_), + seed(seed_) {} + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 1; + scope_min = -1; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + tensor_A.resize(cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); + tensor_A_uncompressed.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + 
tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + tensor_E.resize(cutlass::make_Coord( + problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + tensor_E_reordered.resize(cutlass::make_Coord( + problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + if (init_E == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomSparseMeta( + tensor_E.host_view(), seed, kMetaSizeInBits); + } else if (init_E == cutlass::Distribution::Identity) { + uint32_t content = (kMaxID2 == 1) ? 0x44444444 : 0x4444; + cutlass::reference::host::TensorFill(tensor_E.host_view(), + (ElementE)(content)); + } else { + EXPECT_TRUE(false); + } + + cutlass::reorder_meta(tensor_E_reordered.host_ref(), tensor_E.host_ref(), + {problem_size.m(), problem_size.n(), + problem_size.k() / kSparse / kElementsPerElementE}); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + tensor_E_reordered.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nE =\n" << tensor_E.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" 
<< tensor_D.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::uncompress(tensor_A_uncompressed.host_ref(), tensor_A.host_ref(), + tensor_E.host_ref(), problem_size.m(), problem_size.k()); + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, + ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, + alpha, + tensor_A_uncompressed.host_ref(), + tensor_B.host_ref(), + beta, + reference_D.host_ref(), + ElementAccumulator(0) + ); + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + int split_k_slices = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + split_k_slices, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + tensor_E_reordered.device_data(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0), + tensor_E_reordered.layout().stride(0) + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + // This failure is likely due to insufficient device capabilities. Waive the test. + if (status != cutlass::Status::kSuccess) { + return true; + } + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha << ", beta: " << beta << ", m: " << problem_size.m() << ", n: " << problem_size.n() << ", k:" < +bool TestAllSparseGemm() { + bool passed = true; + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) + // because of the reordering of operand E + int const kAlignmentM = std::max(((sizeof(typename Gemm::ElementE) == 2) ? 
32 : 16), + kMinimumOperandElementSize); + + int const kAlignmentN = 128 / kMinimumOperandElementSize; + + int problem_size_m[] = {kAlignmentM, 512 - 3 * kAlignmentM}; + + int problem_size_n[] = {kAlignmentN, 512 - 2 * kAlignmentN}; + + int problem_size_k[] = {Gemm::ThreadblockShape::kK * 8}; + + int split_k_slices[] = { + 1, 2 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + SparseTestbed testbed; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int split_k : split_k_slices) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + cutlass::gemm::GemmCoord problem_size(m, n, k); + + passed = testbed.run( + problem_size, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_splitk.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_splitk.h new file mode 100644 index 0000000000000000000000000000000000000000..8fa4a85505316d08f1d050702b78448f8fae8565 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_splitk.h @@ -0,0 +1,218 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "testbed.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedSplitK : public Testbed { + + using Base = Testbed; + + using ElementCompute = typename Base::ElementCompute; + + // + // Methods + // + + TestbedSplitK( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + Base(init_A_, init_B_, init_C_, seed_) { } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + int split_k_slices, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + this->tensor_A.device_ref(), + this->tensor_B.device_ref(), + this->tensor_C.device_ref(), + this->tensor_D.device_ref(), + {alpha, beta}, + split_k_slices + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Verify + // + + return this->verify(problem_size, alpha, beta); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllGemmSplitK() { + bool passed = true; + + cutlass::gemm::GemmCoord problem_sizes[] = { + {8, 8, 2048}, + {8, 8, 2056}, + {264, 72, 520}, + {264, 520, 120}, + {264, 520, 264} + }; + + int split_k_slices[] = { + 1, 2, 4, 5, 7 + }; + + double problem_alpha[] = { + 0.5 + }; + + double problem_beta[] = { + 2.0 + }; + + using Testbed = TestbedSplitK; + using ElementCompute = typename Testbed::ElementCompute; + + Testbed testbed; + + for (auto problem_size : problem_sizes) { + for (int split_k_count : split_k_slices) { + for (double alpha : problem_alpha) { + for (double beta : problem_beta) { + + passed = testbed.run( + problem_size, + split_k_count, + ElementCompute(alpha), + ElementCompute(beta) + ); + + if (!passed) { + std::cout << "Failed on size " << problem_size << " with split_k_count " << split_k_count << std::endl; + return false; + } + } + } + } + } + + EXPECT_TRUE(passed); + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} 
// namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_symm_universal.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_symm_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..b7a57f7eb0ca73c23460e5a9ce1301061c2cc286 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_symm_universal.h @@ -0,0 +1,592 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Symm update interface + +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/symm.h" +#include "cutlass/util/reference/host/symm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedSymmUniversal { + + using ElementA = typename Symm::ElementA; + using ElementB = typename Symm::ElementB; + using ElementC = typename Symm::ElementC; + using ElementAccumulator = typename Symm::ElementAccumulator; + using ElementCompute = typename Symm::SymmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + 
uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedSymmUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + + EXPECT_TRUE(false) << "Input distribution not implemented"; + return false; + } + + return true; + } + + + /// Helper to initialize a tensor view + template + bool initialize_symmetric_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, 
+ int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillSymmetricRandomUniform( + view, seed, Symm::kFillModeA, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillSymmetricRandomGaussian( + view, seed, Symm::kFillModeA, 0, 0.5, mantissa_in_bits); + } + else { + + EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; + return false; + } + + return true; + } + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the Symm workspace + // + + if (Symm::kSideModeA == cutlass::SideMode::kLeft) { + tensor_A.resize(cutlass::make_Coord(problem_size.m(),problem_size.m())); + } + else if (Symm::kSideModeA == cutlass::SideMode::kRight) { + tensor_A.resize(cutlass::make_Coord(problem_size.n(),problem_size.n())); + } + + tensor_B.resize(problem_size.mn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + EXPECT_TRUE(initialize_symmetric_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_A.host_view().at({0, 0}) = typename Symm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Symm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Symm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); + + bool passed = l2_norm < cutlass::MantissaInBits::error; + + return passed; + } + + /// Verifies the result is a Symm + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + using HostReference = typename cutlass::platform::conditional< + (cutlass::platform::is_same + >::value || + cutlass::platform::is_same + >::value + ), + cutlass::reference::host::SymmComplex< + typename Symm::ElementA, typename Symm::LayoutA, + Symm::kSideModeA, Symm::kFillModeA, + typename Symm::ElementB, typename Symm::LayoutB, + typename Symm::ElementC, typename Symm::LayoutC, + ElementCompute, + ElementAccumulator, + Symm::kBlasMode>, + cutlass::reference::host::Symm< + typename Symm::ElementA, typename Symm::LayoutA, + Symm::kSideModeA, 
Symm::kFillModeA, + typename Symm::ElementB, typename Symm::LayoutB, + typename Symm::ElementC, typename Symm::LayoutC, + ElementCompute, + ElementAccumulator> + >::type; + + + HostReference reference_symm; + + reference_symm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0) + ); + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Symm::SymmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + +#if 0 + std::cout << "[TestbedSymmUniversal::run()] problem(m, n, k): " << problem_size + << " alpha: " << ElementCompute(alpha) + << " beta: " << ElementCompute(beta) << std::endl; +#endif + + this->initialize(problem_size); + + // + // Initialize the Symm operator + // + + int batch_stride_A; + if (Symm::kSideModeA == cutlass::SideMode::kLeft) + batch_stride_A = problem_size.m()*problem_size.m(); + if (Symm::kSideModeA == cutlass::SideMode::kRight) + batch_stride_A = problem_size.n()*problem_size.n(); + + typename Symm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + batch_stride_A, + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Symm symm_op; + + size_t workspace_size = Symm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = symm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the Symm + // + + status = symm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + //if (true) { + if (!passed) { + std::stringstream fname; + + fname << "error_" + << (Symm::kBlasMode == cutlass::BlasMode::kSymmetric ? "symm_" : "hemm_" ) + << "device_" + << "fill_mode_a_" + << (Symm::kSideModeA == cutlass::SideMode::kLeft ? "leftside_" : + (Symm::kSideModeA == cutlass::SideMode::kRight ? "rightside_" : "invalid_")) + << (Symm::kFillModeA == cutlass::FillMode::kLower ? "lower_" : + (Symm::kFillModeA == cutlass::FillMode::kUpper ? 
"upper_" : "invalid_")) + << "mnk_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Symm::ThreadblockShape::kM << "x" + << Symm::ThreadblockShape::kN << "x" + << Symm::ThreadblockShape::kK << "_" + << Symm::WarpShape::kM << "x" + << Symm::WarpShape::kN << "x" + << Symm::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "alpha: " << ElementCompute(alpha) << "\n" + << "beta: " << ElementCompute(beta) << "\n" + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nD reference:\n" << reference_D.host_view() << "\n" + << "\nD computed:\n" << tensor_D.host_view() << "\n"; + + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestsymmUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedSymmUniversal testbed; + + using ElementCompute = typename Symm::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +template +bool TestAllSymmUniversal() { + bool passed = true; + + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + + int const kAlignment = cutlass::platform::is_same< + typename Symm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 
4 : kAlignment; + + int const kAlignmentN = kAlignmentM; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value + ? 4 : kAlignment; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_m[] = { + kAlignmentK, + Symm::ThreadblockShape::kK * Symm::kStages - kAlignmentK, + Symm::ThreadblockShape::kK * Symm::kStages * 3 - kAlignmentK + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + double problem_alpha[] = { + 1.0, 3.0 + }; + + double problem_beta[] = { + 0, 2.0 + }; + + + using ElementCompute = typename Symm::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + int k = 0; + if (Symm::kSideModeA == cutlass::SideMode::kLeft) + k = m; + else if (Symm::kSideModeA == cutlass::SideMode::kRight) + k = n; + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + + #if 0 + // skip very small K problems + if (k / batch_count < 2 * Symm::ThreadblockShape::kK) { + continue; + } + #endif + } + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + TestbedSymmUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + 
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_trmm_universal.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_trmm_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..b30acfed6bba547986efd3afa8eb829be2a255e4 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_trmm_universal.h @@ -0,0 +1,606 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide TRMM interface + + +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/trmm.h" +#include "cutlass/util/reference/host/trmm_complex.h" +#include "cutlass/core_io.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedTrmmUniversal { + + using ElementA = typename Trmm::ElementA; + using ElementB = typename Trmm::ElementB; + using ElementC = typename Trmm::ElementC; + using ElementAccumulator = typename Trmm::ElementAccumulator; + using ElementCompute = typename Trmm::TrmmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + 
cutlass::Distribution::Kind init_D; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedTrmmUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_D_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_D(init_D_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + + /// Helper to initialize a tensor view + template + bool initialize_symmetric_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int 
mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillSymmetricRandomUniform( + view, seed, Trmm::kFillMode, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillSymmetricRandomGaussian( + view, seed, Trmm::kFillMode, 0, 0.5, mantissa_in_bits); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Helper to initialize a tensor view (pad diagonal fill with zeros for up to alignment on wrong side of diagonal) + template + bool initialize_pad_diagonal_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int alignment) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillPadDiagonalRandomUniform( + view, seed, Trmm::kFillMode, scope_max, scope_min, 0, alignment); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + EXPECT_TRUE(false) << "Gaussian distribution for pad diagonal not implemented"; + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void 
initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the TRMM workspace + // + + if (Trmm::kSideMode == cutlass::SideMode::kLeft) { + tensor_A.resize(cutlass::make_Coord(problem_size.m(),problem_size.m())); + } + else if (Trmm::kSideMode == cutlass::SideMode::kRight) { + tensor_A.resize(cutlass::make_Coord(problem_size.n(),problem_size.n())); + } + + tensor_B.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + //EXPECT_TRUE(initialize_symmetric_tensor(tensor_A.host_view(), init_A, seed + 2017)); + //EXPECT_TRUE(initialize_pad_diagonal_tensor(tensor_A.host_view(), init_A, seed + 2017, Trmm::kAlignmentA)); + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2017, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2019, cutlass::MantissaInBits::bits)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_A.host_view().at({0, 0}) = typename Trmm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Trmm::ElementB(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_D.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); + + bool passed = l2_norm < cutlass::MantissaInBits::error; + + return passed; + } + + /// Verifies the result is a TRMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha) { + + // + // Verify + // + + using HostReference = typename cutlass::platform::conditional< + (cutlass::platform::is_same + >::value || + cutlass::platform::is_same + >::value + ), + cutlass::reference::host::TrmmComplex< + typename Trmm::ElementA, typename Trmm::LayoutA, + Trmm::kTransformA, + Trmm::kSideMode, Trmm::kFillMode, Trmm::kDiagType, + typename Trmm::ElementB, typename Trmm::LayoutB, + Trmm::kTransformB, + typename Trmm::ElementC, typename Trmm::LayoutC, + ElementCompute, + ElementAccumulator>, + cutlass::reference::host::Trmm< + typename Trmm::ElementA, typename Trmm::LayoutA, + Trmm::kSideMode, Trmm::kFillMode, Trmm::kDiagType, + typename Trmm::ElementB, typename Trmm::LayoutB, + typename Trmm::ElementC, typename Trmm::LayoutC, + ElementCompute, + ElementAccumulator> + 
>::type; + + + HostReference reference_trmm; + + reference_trmm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0) + ); + + return compare_reference(problem_size, alpha); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Trmm::TrmmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + +#if 0 + std::cout << "[TestbedTrmmUniversal::run()] problem(m, n, k): " << problem_size + << " alpha: " << ElementCompute(alpha) << std::endl; +#endif + + this->initialize(problem_size); + + // + // Initialize the TRMM operator + // + + int batch_stride_A; + if (Trmm::kSideMode == cutlass::SideMode::kLeft) + batch_stride_A = problem_size.m()*problem_size.m(); + if (Trmm::kSideMode == cutlass::SideMode::kRight) + batch_stride_A = problem_size.n()*problem_size.n(); + + typename Trmm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_D.device_data(), + batch_stride_A, + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Trmm trmm_op; + + size_t workspace_size = Trmm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = trmm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the TRMM + // + + status = trmm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, alpha); + + if (!passed) { + std::stringstream fname; + + fname << "error_Trmm_device_" + << "fill_mode_" + << (Trmm::kFillMode == cutlass::FillMode::kLower ? "lower_" : + (Trmm::kFillMode == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) + << "side_mode_" + << (Trmm::kSideMode == cutlass::SideMode::kLeft ? "left_" : + (Trmm::kSideMode == cutlass::SideMode::kRight ? 
"right_" : "invalid_")) + << "mnk_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Trmm::ThreadblockShape::kM << "x" + << Trmm::ThreadblockShape::kN << "x" + << Trmm::ThreadblockShape::kK << "_" + << Trmm::WarpShape::kM << "x" + << Trmm::WarpShape::kN << "x" + << Trmm::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nD reference:\n" << reference_D.host_view() << "\n" + << "\nD computed:\n" << tensor_D.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestTrmmUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0) { + + bool passed = true; + + TestbedTrmmUniversal testbed; + + using ElementCompute = typename Trmm::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha) + ); + + return passed; +} + +template +bool TestAllTrmmUniversal() { + bool passed = true; + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + + int const kAlignment = cutlass::platform::is_same< + typename Trmm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = kAlignmentM; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value + ? 
4 : kAlignment; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_m[] = { + kAlignmentK, + Trmm::ThreadblockShape::kK * Trmm::kStages - kAlignmentK, + Trmm::ThreadblockShape::kK * Trmm::kStages * 3 - kAlignmentK + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + double problem_alpha[] = { + 1.0, 2.0 + }; + + using ElementCompute = typename Trmm::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int batch_count : batch_counts) { + for (auto alpha : problem_alpha) { + + int k = 0; + if (Trmm::kSideMode == cutlass::SideMode::kLeft) + k = m; + else if (Trmm::kSideMode == cutlass::SideMode::kRight) + k = n; + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + +#if 0 + // skip very small K problems + if (k / batch_count < 2 * Trmm::ThreadblockShape::kK) { + continue; + } +#endif + } + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + TestbedTrmmUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_universal.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_universal.h new file 
mode 100644 index 0000000000000000000000000000000000000000..00368a5e8eebc128719f64069583010c83dc0c1f --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_universal.h @@ -0,0 +1,553 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedUniversal { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + 
cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + bool is_unsigned_int = std::numeric_limits::is_integer && !std::numeric_limits::is_signed; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = is_unsigned_int ? 2 : 1; + scope_min = is_unsigned_int ? 0 : -1; + } else if (bits_output == 16) { + constexpr auto u8_bf16 = + (cutlass::platform::is_same::value && + cutlass::platform::is_same::value) || + (cutlass::platform::is_same::value && + cutlass::platform::is_same::value); + scope_max = is_unsigned_int ? 10 : (u8_bf16 ? 3 : 5); + scope_min = is_unsigned_int ? 0 : (u8_bf16 ? -3 : -5); + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, 
seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + cutlass::Coord<2> origin(0); + tensor_A.host_view().at(origin) = typename Gemm::ElementA(1); + tensor_B.host_view().at(origin) = typename Gemm::ElementB(1); + tensor_C.host_view().at(origin) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed) << " mismatched reference"; + + if (!passed) { + + /* + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + */ + + std::ofstream file("testbed_universal_errors.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", 
beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, + ElementCompute, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0) + ); + + if (Relu) { + for (int i = 0; i < problem_size.m(); ++i) { + for (int j = 0; j < problem_size.n(); ++j) { + reference_D.at(cutlass::MatrixCoord(i, j)) = + ((ElementCompute)reference_D.at(cutlass::MatrixCoord(i, j)) < (ElementCompute)0) + ? (typename Gemm::ElementC)0 + : reference_D.at(cutlass::MatrixCoord(i, j)); + } + } + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. 
+ bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) + { +/* + std::cout << "\n-----------------------\n"; + std::cout << "mode: " << (int) mode << "\n"; + std::cout << "problem size: " << problem_size << "\n"; + std::cout << "batch_count: " << batch_count << "\n"; + std::cout << "alpha: " << alpha << "\n"; + std::cout << "beta: " << beta << "\n"; + std::cout << "-----------------------\n\n"; +*/ + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestGemmUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedUniversal testbed; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +template +bool TestAllGemmUniversal() { + bool passed = true; + + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + 
int(cutlass::sizeof_bits::value)); + + int const kAlignment = cutlass::platform::is_same< + typename Gemm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + (cutlass::platform::is_same::value || + cutlass::platform::is_same::value) ? 4 : kAlignment; + + + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_m[] = { + kAlignmentM, 512 - 3*kAlignmentM + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int problem_size_k[] = { + kAlignmentK, + Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK, + Gemm::ThreadblockShape::kK * Gemm::kStages * 3 - kAlignmentK + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1, 2, 3, 5, 7 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + + // skip very small K problems + if (k / batch_count < 2 * Gemm::ThreadblockShape::kK) { + continue; + } + } + + 
cutlass::gemm::GemmCoord problem_size(m, n, k); + + TestbedUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + } + + /* + // large problem with high coverage + for (int split_k_slices = 1; split_k_slices <= 3; ++split_k_slices) { + TestbedUniversal testbed; + + cutlass::gemm::GemmCoord problem_size(72, 56, 8192); + + passed = testbed.run( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + split_k_slices, + cutlass::from_real(1.0), + cutlass::from_real(2.0) + ); + + if (!passed) { + break; + } + } + */ + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_utils.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..89ac33a1028061515d08d50fdb6cce7833ae88ce --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_utils.h @@ -0,0 +1,53 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +inline char const *to_string(cutlass::Status status) { + + switch (status) { + case cutlass::Status::kSuccess: return "kSuccess"; + case cutlass::Status::kErrorMisalignedOperand: return "kErrorMisalignedOperand"; + case cutlass::Status::kErrorInvalidLayout: return "kErrorInvalidLayout"; + case cutlass::Status::kErrorInvalidProblem: return "kErrorInvalidProblem"; + case cutlass::Status::kErrorNotSupported: return "kErrorNotSupported"; + case cutlass::Status::kErrorWorkspaceNull: return "kErrorWorkspaceNull"; + case cutlass::Status::kErrorInternal: return "kErrorInternal"; + case cutlass::Status::kInvalid: return "kInvalid"; + default: break; + } + return "invalid"; +} diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_with_absmax.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_with_absmax.h new file mode 100644 index 0000000000000000000000000000000000000000..8b5588f57c40c4e8f8d06adfa9f1e673350fb5e5 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_with_absmax.h @@ -0,0 +1,609 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Testbed for running device-level GEMMs with absolute maximum calculation and scaling +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" +#include "testbed_sparse.h" +#include "testbed_utils.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename GemmTestbed, + template class ActivationFunctor +> +struct TestbedWithAmax { + + static_assert(std::is_same_v> || std::is_same_v>); + static constexpr bool IsSparseTestbed = std::is_same_v>; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + using ElementScalingFactor = typename Gemm::EpilogueOutputOp::ElementScalingFactor; + using ElementAbsmax = typename Gemm::EpilogueOutputOp::ElementAbsmax; + + static bool const kScaleAux = Gemm::EpilogueOutputOp::kIsScalingAndAmaxAuxOutputNeeded; + static bool const kScaleOutput = Gemm::EpilogueOutputOp::kIsScalingAndAmaxOutputNeeded; + bool doScaleA; + bool doScaleB; + bool doScaleC; + + GemmTestbed underlying_testbed; + + cutlass::HostTensor tensor_Aux; + cutlass::HostTensor tensor_Vector; + cutlass::HostTensor tmp_D; + cutlass::HostTensor reference_D; + cutlass::HostTensor reference_Aux; + cutlass::HostTensor scale_A; + cutlass::HostTensor 
scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; + + // + // Methods + // + + TestbedWithAmax( + bool scaleA = true, + bool scaleB = true, + bool scaleC = true, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform + ): + doScaleA(scaleA), doScaleB(scaleB), doScaleC(scaleC), + underlying_testbed(init_A_, init_B_, init_C_) { } + + /// Helper to initialize scaling factors + template + bool initialize_scale_factor(cutlass::TensorView view, uint64_t seed, int bits=0) { + cutlass::reference::host::TensorFillRandomUniform(view, seed, double(1.), double(0.), bits); + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + underlying_testbed.initialize(problem_size); + + tensor_Vector.resize({1, problem_size.n()}); + reference_D.resize(problem_size.mn(), false); + tmp_D.resize(problem_size.mn(), false); + + EXPECT_TRUE( + underlying_testbed.initialize_tensor(tensor_Vector.host_view(), underlying_testbed.init_C, underlying_testbed.seed + 2020) + ); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ cutlass::Coord<2> origin(0); + tensor_Vector.host_view().at(origin) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), underlying_testbed.tensor_C.host_view()); + + tensor_Vector.sync_device(); + + int scale_bits = 2; + if (doScaleA) { + scale_A.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_A.host_view(), underlying_testbed.seed + 2021, scale_bits)); + scale_A.sync_device(); + } + + if (doScaleB) { + scale_B.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_B.host_view(), underlying_testbed.seed + 2022, scale_bits)); + scale_B.sync_device(); + } + + if (doScaleC) { + scale_C.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_C.host_view(), underlying_testbed.seed + 2023, scale_bits)); + scale_C.sync_device(); + } + + if (kScaleOutput) { + scale_D.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_D.host_view(), underlying_testbed.seed + 2024, scale_bits)); + scale_D.sync_device(); + + abs_max_D.resize({1, 1}); + cutlass::reference::host::TensorFill(abs_max_D.host_view()); + abs_max_D.sync_device(); + + reference_abs_max_D.resize({1, 1}); + } + + if (kScaleAux) { + tensor_Aux.resize(problem_size.mn()); + cutlass::reference::host::TensorFill(tensor_Aux.host_view()); + tensor_Aux.sync_device(); + + scale_Aux.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_Aux.host_view(), underlying_testbed.seed + 2025, scale_bits)); + scale_Aux.sync_device(); + + abs_max_Aux.resize({1, 1}); + cutlass::reference::host::TensorFill(abs_max_Aux.host_view()); + abs_max_Aux.sync_device(); + + reference_Aux.resize(problem_size.mn(), false); + reference_abs_max_Aux.resize({1, 1}); + } + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + underlying_testbed.tensor_D.sync_host(); + + 
EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), underlying_testbed.tensor_D.host_view()); + if (!passed) { + std::cout << "Comparison of D failed" << std::endl; + } + + if (kScaleAux) { + tensor_Aux.sync_host(); + abs_max_Aux.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Aux.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_Aux.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_Aux.host_view()), 0); + if (!cutlass::reference::host::TensorEquals(reference_Aux.host_view(), tensor_Aux.host_view())) { + passed = false; + std::cout << "Comparison of Aux failed" << std::endl; + } + if (!cutlass::reference::host::TensorEquals(abs_max_Aux.host_view(), reference_abs_max_Aux.host_view())) { + passed = false; + std::cout << "Comparison of Aux absmax failed" << std::endl; + } + } + + if (kScaleOutput) { + abs_max_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_D.host_view()), 0); + if (!cutlass::reference::host::TensorEquals(abs_max_D.host_view(), reference_abs_max_D.host_view())) { + passed = false; + std::cout << "Comparison of D absmax failed" << std::endl; + } + } + + EXPECT_TRUE(passed) << " mismatched reference"; + + if (!passed) { + + std::ofstream file("testbed_with_amax_errors.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << underlying_testbed.tensor_A.host_view() + << "\nB =\n" << 
underlying_testbed.tensor_B.host_view() + << "\nC =\n" << underlying_testbed.tensor_C.host_view() + << "\nVector =\n" << tensor_Vector.host_view() + << "\nScaleA = " << scale_A.host_view() + << "\nScaleB = " << scale_B.host_view() + << "\nScaleC = " << scale_C.host_view() + << "\nScaleD = " << scale_D.host_view() + << "\nScaleAux = " << scale_Aux.host_view() + << "\n\nReference D =\n" << reference_D.host_view() + << "\nComputed D =\n" << underlying_testbed.tensor_D.host_view(); + if (kScaleAux) { + file + << "\n\nReference Aux =\n" << reference_Aux.host_view() + << "\nComputed Aux =\n" << tensor_Aux.host_view() + << "\n\nReference Absmax Aux = " << reference_abs_max_Aux.host_view() + << "\nComputed Absmax Aux = " << abs_max_Aux.host_view(); + } + if (kScaleOutput) { + file + << "\n\nReference Absmax D = " << reference_abs_max_D.host_view() + << "\nComputed Absmax D = " << abs_max_D.host_view(); + } + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + cutlass::Coord<2> origin(0); + ElementCompute scaled_alpha = alpha; + if (doScaleA) { + scaled_alpha *= scale_A.host_view().at(origin); + } + if (doScaleB) { + scaled_alpha *= scale_B.host_view().at(origin); + } + + ElementCompute scaled_beta = beta; + if (doScaleC) { + scaled_beta *= scale_C.host_view().at(origin); + } + + // + // Verify + // + + auto ref_tA = [&](){ + if constexpr (IsSparseTestbed) { + cutlass::uncompress( + underlying_testbed.tensor_A_uncompressed.host_ref(), + underlying_testbed.tensor_A.host_ref(), + underlying_testbed.tensor_E.host_ref(), + problem_size.m(), + problem_size.k() + ); + return underlying_testbed.tensor_A_uncompressed.host_ref(); + } + else { + return underlying_testbed.tensor_A.host_ref(); + } + }(); + + // Run reference kernel with ElementOutput of type ElementAccumulator + // so that we can compute the absmax epilogue on data that is of type + // 
ElementAccumulator (which is what the GEMM we are testing will do). + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, + ElementCompute, ElementAccumulator, ElementAccumulator + >( + problem_size, + scaled_alpha, + ref_tA, + Gemm::kTransformA, + underlying_testbed.tensor_B.host_ref(), + Gemm::kTransformB, + scaled_beta, + underlying_testbed.tensor_C.host_ref(), + tmp_D.host_ref(), + ElementAccumulator(0) + ); + + ElementCompute tmp_abs_max_Aux(0.); + ElementCompute tmp_abs_max_D(0.); + + cutlass::NumericConverter cvt_c_to_compute; + cutlass::NumericConverter cvt_accum_to_compute; + cutlass::NumericConverter cvt_compute_to_absmax; + cutlass::NumericConverter cvt_compute_to_d; + cutlass::NumericConverter cvt_compute_to_aux; + + cutlass::absolute_value_op abs; + cutlass::maximum_with_nan_propogation max; + ActivationFunctor act; + + ElementScalingFactor d_scale = kScaleOutput ? 
scale_D.host_view().at(origin) : ElementScalingFactor(1.); + + for (int m = 0; m < problem_size.m(); ++m) { + for (int n = 0; n < problem_size.n(); ++n) { + ElementCompute intermediate = cvt_accum_to_compute(tmp_D.host_view().at({m, n})); + ElementCompute bias = cvt_c_to_compute(tensor_Vector.host_view().at({0, n})); + ElementCompute aux = intermediate + bias; + ElementCompute d = act(aux); + tmp_abs_max_Aux = max(abs(aux), tmp_abs_max_Aux); + tmp_abs_max_D = max(abs(d), tmp_abs_max_D); + reference_D.host_view().at({m, n}) = cvt_compute_to_d(d * d_scale); + + if (kScaleAux) { + reference_Aux.host_view().at({m, n}) = cvt_compute_to_aux(aux * scale_Aux.host_view().at(origin)); + } + } + } + + if (kScaleAux) { + reference_abs_max_Aux.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_Aux); + } + + if (kScaleOutput) { + reference_abs_max_D.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_D); + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + return underlying_testbed.sufficient(); + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) + { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::EpilogueOutputOp::Params::ActivationParams activation_params{alpha, beta}; + typename Gemm::EpilogueOutputOp::Params epilogue_params{ + activation_params, + scale_A.device_data(), + scale_B.device_data(), + scale_C.device_data(), + scale_D.device_data(), + scale_Aux.device_data(), + abs_max_Aux.device_data(), + abs_max_D.device_data() + }; + + auto arguments = [&]() { + if constexpr (IsSparseTestbed) { + return typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + batch_count, + epilogue_params, + underlying_testbed.tensor_A.device_data(), + underlying_testbed.tensor_B.device_data(), + underlying_testbed.tensor_C.device_data(), + underlying_testbed.tensor_D.device_data(), + underlying_testbed.tensor_E_reordered.device_data(), + tensor_Aux.device_data(), + tensor_Vector.device_data(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + underlying_testbed.tensor_A.layout().stride(0), + underlying_testbed.tensor_B.layout().stride(0), + underlying_testbed.tensor_C.layout().stride(0), + underlying_testbed.tensor_D.layout().stride(0), + underlying_testbed.tensor_E_reordered.layout().stride(0), + tensor_Aux.layout().stride(0), + 0 // stride vector + }; + } + else { + return typename Gemm::Arguments{ + mode, + problem_size, + batch_count, + epilogue_params, + underlying_testbed.tensor_A.device_data(), + underlying_testbed.tensor_B.device_data(), + underlying_testbed.tensor_C.device_data(), + underlying_testbed.tensor_D.device_data(), + tensor_Aux.device_data(), + tensor_Vector.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + 0, // stride vector + underlying_testbed.tensor_A.layout().stride(0), + underlying_testbed.tensor_B.layout().stride(0), 
+ underlying_testbed.tensor_C.layout().stride(0), + underlying_testbed.tensor_D.layout().stride(0), + (int64_t)0 // Leading dimension of vector. This must be 0 + }; + } + }(); + + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = gemm_op.initialize(arguments, workspace.get()); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + cudaError_t cuda_error = cudaDeviceSynchronize(); + EXPECT_TRUE(cuda_error == cudaSuccess) << cudaGetErrorString(cuda_error); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename GemmTestbed, + template class ActivationFunctor = cutlass::epilogue::thread::Identity +> +bool TestAllGemmWithAbsmax(bool scaleA=true, bool scaleB=true, bool scaleC=true) { + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + int constexpr kAlignmentM = [&]() { + if constexpr (std::is_same_v>) { + // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) + // because of the reordering of operand E + return std::max(((sizeof(typename Gemm::ElementE) == 2) ? 
32 : 16), + kMinimumOperandElementSize); + } + else { + return 128 / kMinimumOperandElementSize; + } + }(); + + int const kAlignmentN = 128 / kMinimumOperandElementSize; + + int M_problems[] = {kAlignmentM, 128 + 32}; + int N_problems[] = {kAlignmentN, 512 - 2 * kAlignmentN}; + int K_problems[] = {Gemm::ThreadblockShape::kK * 2}; + double alpha_problems[] = {1.}; + double beta_problems[] = {0.}; + int split_k_slices[] = { + 1, 2 + }; + + bool passed = true; + + for (int M : M_problems) { + for (int N : N_problems) { + for (int K : K_problems) { + for (int split_k : split_k_slices) { + if (cutlass::sizeof_bits_v <= 8 && split_k > 1) { + // Don't test split-K with FP8 output. The kernel being tested will writie partial accumulations + // for different splits to global memory in FP8, while the reference kernel will not. This leads + // to mismatches that are difficult to capture without a permissive relative equality check threshold. + continue; + } + + for (double alpha : alpha_problems) { + for (double beta : beta_problems) { + TestbedWithAmax testbed(scaleA, scaleB, scaleC); + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + EXPECT_TRUE(passed) + << "M: " << M << ", N: " << N << ", K: " << K << ", alpha: " << alpha << ", beta: " << beta << ", split_k:" << split_k; + + if (!passed) { + + return passed; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/kernel/testbed_gemv.h 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/kernel/testbed_gemv.h new file mode 100644 index 0000000000000000000000000000000000000000..8e939f9710403a5f5c3fd8c61e34c4e8021ff423 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/kernel/testbed_gemv.h @@ -0,0 +1,358 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/core_io.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "cutlass/gemm/kernel/default_gemv.h" +#include "cutlass/gemm/kernel/gemv_batched_strided.h" + +namespace test { +namespace gemm { +namespace kernel { + +template +void batched_gemv_kernel_test(cutlass::gemm::BatchedGemmCoord problem_size, + ElementCD_ alpha = ElementCD_(1), + ElementCD_ beta = ElementCD_(0), + bool perf_test = false, + int perf_test_iter = 1) +{ + using ThreadBlockShape = ThreadBlockShape_; + using ThreadShape = ThreadShape_; + using ElementA = ElementAB_; + using LayoutA = LayoutA_; + using ElementB = ElementAB_; + using LayoutB = LayoutB_; + using ElementAccumulator = ElementCD_; + using ElementCD = ElementCD_; + using LayoutCD = LayoutCD_; + + using GemvKernel = cutlass::gemm::kernel::DefaultGemv; + + using ThreadBlockGemv = typename GemvKernel::ThreadBlockGemv; + using ThreadBlockSwizzle = typename 
GemvKernel::ThreadBlockSwizzle; + + if (DEBUG) + { + problem_size = cutlass::gemm::BatchedGemmCoord( + problem_size.m(), problem_size.n(), problem_size.k(), 1); + } + + // Create host tensors that will be the backing store for the batches + // Note that no device memory is initially allocated + cutlass::HostTensor matrix_A({problem_size.m(), problem_size.k()}, false); + cutlass::HostTensor matrix_B({problem_size.k(), problem_size.n()}, false); + cutlass::HostTensor matrix_C_computed({problem_size.m(), problem_size.n()}, false); + cutlass::HostTensor matrix_C_reference({problem_size.m(), problem_size.n()}, false); + + // Reserve memory for the batch of tensors + matrix_A.reserve(problem_size.m()*problem_size.k()*problem_size.batch()); + matrix_B.reserve(problem_size.n()*problem_size.k()*problem_size.batch()); + matrix_C_computed.reserve(problem_size.m()*problem_size.n()*problem_size.batch()); + matrix_C_reference.reserve(problem_size.m()*problem_size.n()*problem_size.batch(), false); + + // Fill eatch tensor batch + const int seed = 9876; + for (int b = 0; b < problem_size.batch(); b++) + { + if(DEBUG) + { + cutlass::reference::host::BlockFillSequential( + matrix_A.host_data_ptr_offset(b*matrix_A.capacity()), matrix_A.capacity()); + cutlass::reference::host::BlockFillSequential( + matrix_B.host_data_ptr_offset(b*matrix_B.capacity()), matrix_B.capacity()); + } + else + { + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(b*matrix_A.capacity()), + seed + 1660, + 8, + -8, + 0 + ); + + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(b*matrix_B.capacity()), + seed + 1880, + 8, + -8, + 0 + ); + } + + cutlass::reference::host::TensorFill(matrix_C_computed.host_view(b*matrix_C_computed.capacity())); + cutlass::reference::host::TensorFill(matrix_C_reference.host_view(b*matrix_C_reference.capacity())); + } + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + ThreadBlockSwizzle swizzle; + + 
cutlass::gemm::BatchedGemmCoord tiled_size{ThreadBlockShape::kM, + ThreadBlockShape::kN, + problem_size.k(), // no split-k + DEBUG ? 1 : THREAD_B }; + + cutlass::gemm::BatchedGemmCoord tiled_shape = swizzle.get_tiled_shape(problem_size, tiled_size); + + #if 0 + printf("tiled_size = %d %d %d %d\n", tiled_size.m(), tiled_size.n(), tiled_size.k(), tiled_size.batch()); + printf("tiled_shape = %d %d %d %d\n", tiled_shape.m(), tiled_shape.n(), tiled_shape.k(), tiled_shape.batch()); + #endif + + // No split-k + EXPECT_EQ(tiled_size.k(), problem_size.k()); + + dim3 grid = swizzle.get_grid_shape(tiled_shape); + dim3 block(tiled_size.n() / ThreadShape::kN, tiled_size.batch(), tiled_size.k() / problem_size.k()); + + // Some sanity checks + EXPECT_TRUE( block.x*block.y*block.z <= 1024 ); + EXPECT_TRUE( block.x <= 1024 ); + EXPECT_TRUE( block.y <= 1024 ); + EXPECT_TRUE( block.z <= 64 ); + + #if 0 + printf("grid dim = %d, %d, %d\n", grid.x, grid.y, grid.z); + printf("block dim = %d, %d, %d\n", block.x, block.y, block.z); + #endif + + cudaError_t result; + cudaEvent_t start_event, end_event; + + for (int iter = 0; iter < (perf_test ? 
(perf_test_iter+1) : 1); ++iter) + { + if (perf_test && iter == 1) + { + result = cudaEventCreate(&start_event); + EXPECT_EQ(result, cudaSuccess); + + result = cudaEventCreate(&end_event); + EXPECT_EQ(result, cudaSuccess); + + result = cudaEventRecord(start_event); + EXPECT_EQ(result, cudaSuccess); + } + + if (beta == ElementCD(0)) + { + if (alpha == ElementCD(1)) + { + cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( + problem_size, + matrix_A.device_ref(), + matrix_A.capacity(), + matrix_B.device_ref(), + matrix_B.capacity(), + matrix_C_computed.device_ref(), + matrix_C_computed.capacity() + ); + } + else + { + cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( + problem_size, + alpha, + matrix_A.device_ref(), + matrix_A.capacity(), + matrix_B.device_ref(), + matrix_B.capacity(), + matrix_C_computed.device_ref(), + matrix_C_computed.capacity() + ); + } + } + else + { + cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( + problem_size, + alpha, + beta, + matrix_A.device_ref(), + matrix_A.capacity(), + matrix_B.device_ref(), + matrix_B.capacity(), + matrix_C_computed.device_ref(), + matrix_C_computed.capacity(), + matrix_C_computed.device_ref(), + matrix_C_computed.capacity() + ); + } + + if (iter == 0) + { + result = cudaGetLastError(); + EXPECT_EQ(result, cudaSuccess) << " kernel error: " << cudaGetErrorString(result); + } + } + + if (perf_test) + { + result = cudaEventRecord(end_event); + EXPECT_EQ(result, cudaSuccess); + } + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " kernel error: " << cudaGetErrorString(result); + + if (perf_test) + { + float ms; + result = cudaEventElapsedTime(&ms, start_event, end_event); + EXPECT_EQ(result, cudaSuccess); + + double flops = (double(problem_size.m()) * + double(problem_size.n()) * + double(problem_size.k()) * + double(problem_size.batch()) * 2); // 2 for MAC + + double read_bytes = double(problem_size.batch()) * 
(sizeof(ElementA)*double(problem_size.m())*double(problem_size.k()) + + sizeof(ElementB)*double(problem_size.k())*double(problem_size.n())); + + double write_bytes = double(problem_size.batch()) * (sizeof(ElementCD)*double(problem_size.m())*double(problem_size.n())); + + double avg_runtime = double(ms) / perf_test_iter; + double gflops_per_sec = flops / 1.0e6 / avg_runtime; + double read_bandwidth = read_bytes / 1.0e6 / avg_runtime; + double write_bandwidth = write_bytes / 1.0e6 / avg_runtime; + + std::cout << "\n\nProblem size: " + << problem_size.m() + << " x " << problem_size.n() + << " x " << problem_size.k() + << " x " << problem_size.batch() + << std::endl; + + std::cout << " GFLOPs: " << gflops_per_sec << std::endl; + std::cout << "BW (R/W): " << read_bandwidth << " / " << write_bandwidth << " GB/sec" << std::endl; + std::cout << " Runtime: " << avg_runtime << " ms" << std::endl; + } + else + { + matrix_C_computed.sync_host(); + + // Compute the batched gemms + for (int b = 0; b < problem_size.batch(); b++) + { + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + problem_size.mnk(), alpha, + matrix_A.host_ref(b * matrix_A.capacity()), + matrix_B.host_ref(b * matrix_B.capacity()), beta, + matrix_C_reference.host_ref(b * matrix_C_computed.capacity())); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(b * matrix_C_computed.capacity()), + matrix_C_reference.host_view(b * matrix_C_reference.capacity())); + + EXPECT_TRUE(passed) + //<< "A:\n" << matrix_A.host_view() << "\n" + //<< "B:\n" << matrix_B.host_view() << "\n" + << "Batch: " << b << "\n" + << "Reference:\n" + << matrix_C_reference.host_view(b * matrix_C_reference.capacity()) + << "\n" + << "Computed:\n" + << matrix_C_computed.host_view(b * matrix_C_computed.capacity()) + << "\n"; + } + } +} + +template +void batched_gemv_kernel_perf_test(cutlass::gemm::BatchedGemmCoord problem_size, + ElementCD_ alpha = ElementCD_(1), + ElementCD_ beta = 
ElementCD_(0), + int iter = 50) +{ + batched_gemv_kernel_test(problem_size, alpha, beta, true, iter); +} + +} // namespace threadblock +} // namespace kernel +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/thread/host/testbed_host.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/thread/host/testbed_host.h new file mode 100644 index 0000000000000000000000000000000000000000..6e3d6ab079d44345f2f55f4126ba3efc1eba47cb --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/thread/host/testbed_host.h @@ -0,0 +1,232 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include "cutlass/gemm/thread/mma.h" +#include "cutlass/layout/vector.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace test { +namespace gemm { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level matrix multiply-accumulate +template +void kernel( + typename Mma::ElementC *D, + typename Mma::ElementA const *A, + typename Mma::ElementB const *B, + typename Mma::ElementC const *C) { + + auto ptr_D = reinterpret_cast *>(D); + auto ptr_A = reinterpret_cast const *>(A); + auto ptr_B = reinterpret_cast const *>(B); + auto ptr_C = reinterpret_cast const *>(C); + + Mma mma; + + auto a = *ptr_A; + auto b = *ptr_B; + auto c = *ptr_C; + + using Btype = typename Mma::ElementB; + cutlass::Array d; + + mma(d, a, b, c); + + *ptr_D = d; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - 
concept: gemm::GemmShape<> + typename Shape, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC +> +struct Testbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = cutlass::gemm::thread::Mma< + Shape, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC + >; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed() { + + tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK), false); + tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN), false); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Runs the test + bool run() { + + // + // initialize device memory + // + + cutlass::reference::host::detail::RandomUniformFunc< ElementA > tfill_rand_func( + 0, // seed + 10, // max + 0, // min + 0); // bits after decimal + + cutlass::reference::host::detail::TensorFillRandomUniformFunc< ElementA, LayoutA > tfill_rand( + tensor_A.host_view(), + tfill_rand_func); + + for (auto i=0; i< Shape::kM; i++) + for (auto j=0; j< Shape::kK; j++) + tfill_rand(cutlass::make_Coord(i,j)); + + cutlass::reference::host::BlockFillSequential( + tensor_B.host_data(), + tensor_B.capacity(), + ElementB(1), + ElementB(2) + ); + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + 
cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + + // Host side call + kernel( + tensor_D_computed.host_data(), + tensor_A.host_data(), + tensor_B.host_data(), + tensor_C.host_data()); + + // + // Reference implementation + // + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, Shape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed) + << "A:\n" << tensor_A.host_view() << "\n\n" + << "B:\n" << tensor_B.host_view() << "\n\n" + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + + + return passed; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace gemm +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/thread/testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/thread/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..8d34d7992b57cefa0eaf7300a5e1fb49f41a93e2 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/thread/testbed.h @@ -0,0 +1,236 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include "cutlass/gemm/thread/mma.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace test { +namespace gemm { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level matrix multiply-accumulate +template +__global__ void kernel( + typename Mma::ElementC *D, + typename Mma::ElementA const *A, + typename Mma::ElementB const *B, + typename Mma::ElementC const *C) { + + auto ptr_D = reinterpret_cast *>(D); + auto ptr_A = reinterpret_cast const *>(A); + auto ptr_B = reinterpret_cast const *>(B); + auto ptr_C = reinterpret_cast const *>(C); + + Mma mma; + + auto a = *ptr_A; + auto b = *ptr_B; + auto c = *ptr_C; + + cutlass::Array d; + + mma(d, a, b, c); + + *ptr_D = d; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC +> +struct Testbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = cutlass::gemm::thread::Mma< + Shape, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC + >; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + 
cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed() { + + tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK)); + tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Runs the test + bool run() { + + // + // initialize device memory + // + + cutlass::reference::host::BlockFillSequential( + tensor_A.host_data(), + tensor_A.capacity() + ); + + cutlass::reference::host::BlockFillSequential( + tensor_B.host_data(), + tensor_B.capacity(), + ElementB(1), + ElementB(2) + ); + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel<<< dim3(1, 1), dim3(1, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + //tensor_D_reference.fill(tensor_C.host_view()); + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, Shape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify 
equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed) + << "A:\n" << tensor_A.host_view() << "\n\n" + << "B:\n" << tensor_B.host_view() << "\n\n" + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace gemm +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..1f3bc8cf114d7eb2ac00bd19ae92c984558b7228 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h @@ -0,0 +1,435 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/core_io.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" + +namespace test { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_multistage_mma_sparse(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc, + typename Mma::IteratorE::Params params_E, + typename Mma::IteratorE::TensorRef ref_E) { + // Shared storage needed by threadblock-scoped matrix multiply- + // Dynamic shared memory base pointer + extern __shared__ int GemmSharedStorageBase[]; + + // Declare pointer to dynamic shared memory. 
+ typename Mma::SharedStorage *shared_storage = + reinterpret_cast(GemmSharedStorageBase); + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k() / Mma::kSparse}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + cutlass::MatrixCoord tb_offset_E{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k() / Mma::kSparse}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k() / Mma::kSparse}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + typename Mma::IteratorE iterator_E( + params_E, ref_E.data(), + {problem_size.m(), + problem_size.k() / Mma::kSparse / Mma::kElementsPerElementE}, + tb_thread_id, tb_offset_E); + + int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); + + // Construct thread-scoped matrix multiply + Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_E, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, threadIdx.x); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +//////////////////////////////////////////////////////////////////////////////// + +/// 
Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_> +struct SparseTestbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + using ElementE = typename MmaCore::ElementE; + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using ThreadMapE = typename MmaCore::IteratorThreadMapE; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + using AccessTypeE = cutlass::Array; + static int const Stages = MmaCore::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + MmaCore::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + MmaCore::kCacheOpB; + static cutlass::arch::CacheOperation::Kind const CacheOpE = + MmaCore::kCacheOpE; + + static int const Sparse = MmaCore::kSparse; + static int const MetaSizeInBits = MmaCore::kMetaSizeInBits; + static int const MaxID2 = MmaCore::kMaxID2; + + using LayoutE = cutlass::layout::RowMajor; + using ReorderedLayoutE = typename MmaCore::GmemLayoutE; + + static int const ElementsPerElementE = MmaCore::kElementsPerElementE; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + 
cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define iterators over tiles from the E operand + using IteratorE = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementE, ReorderedLayoutE, 1, ThreadMapE, AccessTypeE>; + + // Define the threadblock-scoped pipelined matrix multiply + using Mma = cutlass::gemm::threadblock::SparseMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + CacheOpA, IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, + LayoutC, IteratorE, typename MmaCore::SmemIteratorE, CacheOpE, + typename MmaCore::MmaPolicy, Stages>; + + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_A_uncompressed; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed; + cutlass::HostTensor matrix_C_reference; + cutlass::HostTensor matrix_E; + cutlass::HostTensor matrix_E_reordered; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + SparseTestbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k / Sparse)); + matrix_A_uncompressed.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + matrix_E.reset(cutlass::make_Coord(m, k / Sparse / ElementsPerElementE)); + matrix_E_reordered.reset( + cutlass::make_Coord(m, k / Sparse / ElementsPerElementE)); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. 
+ bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + return true; + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_E = cutlass::Distribution::Uniform) { + + // Waive the test + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == 
cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + if (init_E == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomSparseMeta( + matrix_E.host_view(), seed, MetaSizeInBits); + } else if (init_E == cutlass::Distribution::Identity) { + uint32_t content = (MaxID2 == 1) ? 0x44444444 : 0x4444; + cutlass::reference::host::TensorFill(matrix_E.host_view(), + (ElementE)(content)); + } else { + return false; + } + + cutlass::reorder_meta(matrix_E_reordered.host_ref(), matrix_E.host_ref(), + {problem_size.m(), problem_size.n(), + problem_size.k() / Sparse / ElementsPerElementE}); + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + matrix_E_reordered.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + typename IteratorE::Params params_E(matrix_E_reordered.layout()); + + cudaError_t result; + + int smem_size = int(sizeof(typename Mma::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma_sparse, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + return true; + } + + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma_sparse, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + if (result != cudaSuccess) { + return true; + } + } + + test::gemm::threadblock::kernel_multistage_mma_sparse + <<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), 
matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0), params_E, + matrix_E_reordered.device_ref()); + + // + // Check error code + // + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + matrix_C_computed.sync_host(); + + cutlass::uncompress(matrix_A_uncompressed.host_ref(), matrix_A.host_ref(), + matrix_E.host_ref(), problem_size.m(), + problem_size.k()); + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm(problem_size, ElementC(alpha), + matrix_A_uncompressed.host_view(), matrix_B.host_view(), + ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed && CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + + std::cout + << __FILE__ << ":" << __LINE__ << " " + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "E:\n" << matrix_E.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + } + + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_reference.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_computed.host_view()), 0); + + return passed; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..5caaf38ace92758bbc86970d8d4ff339d87348ab --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h @@ -0,0 +1,372 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/core_io.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +namespace test { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_multistage_mma(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc) { + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + + // Dynamic shared memory base pointer + extern __shared__ int GemmSharedStorageBase[]; + + // Declare pointer to dynamic shared memory. 
+ typename Mma::SharedStorage *shared_storage = + reinterpret_cast(GemmSharedStorageBase); + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); + + // Construct thread-scoped matrix multiply + Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, threadIdx.x); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_> +struct Testbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using 
InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + static int const Stages = MmaCore::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + MmaCore::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + MmaCore::kCacheOpB; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define the threadblock-scoped pipelined matrix multiply + using Mma = cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + CacheOpA, IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, + LayoutC, typename MmaCore::MmaPolicy, Stages>; + + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed; + cutlass::HostTensor matrix_C_reference; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k)); + 
matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + // + // Determine SMEM requirements and waive if not satisfied + // + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + return true; + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t 
seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + cudaError_t result; + + int smem_size = int(sizeof(typename Mma::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + if (result != cudaSuccess) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + } + + test::gemm::threadblock::kernel_multistage_mma + <<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0)); + + // + // Check error code + // + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + matrix_C_computed.sync_host(); + + cutlass::reference::host::Gemm reference_gemm; + + reference_gemm( + problem_size, ElementC(alpha), matrix_A.host_view(), + matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed && CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cout + << __FILE__ << ":" << __LINE__ << " " + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + } + + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_reference.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_computed.host_view()), 0); + + return passed; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h new file mode 100644 index 0000000000000000000000000000000000000000..4e617d6327594570b1a88a5b28f2ec4d0467b534 --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h @@ -0,0 +1,387 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +namespace test { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_multistage_mma(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC **ptr_C, + typename Mma::LayoutC::Stride::Index ldc) { + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + + // Dynamic shared memory base pointer + extern __shared__ int GemmSharedStorageBase[]; + + // Declare pointer to dynamic shared memory. 
+ typename Mma::SharedStorage *shared_storage = + reinterpret_cast(GemmSharedStorageBase); + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); + int lane_id = threadIdx.x; + + int partitionsK_idx = warp_id / (Mma::WarpCount::kM * Mma::WarpCount::kN); + + // Construct thread-scoped matrix multiply + Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C[partitionsK_idx], ldc}, lane_id); + + int warp_idx_mn = warp_id % (Mma::WarpCount::kM * Mma::WarpCount::kN); + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_idx_mn % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_idx_mn / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename 
MmaCore_> +struct Testbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + static int const Stages = MmaCore::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + MmaCore::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + MmaCore::kCacheOpB; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define the threadblock-scoped pipelined matrix multiply + using Mma = cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, CacheOpA, + IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, LayoutC, + typename MmaCore::MmaPolicy, Stages>; + + static int const kPartitionsK = MmaCore::MmaPolicy::kPartitionsK; + + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed[kPartitionsK]; + cutlass::HostTensor matrix_C_reference; + cutlass::HostTensor matrix_C_pointers; + + 
cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].reset(cutlass::make_Coord(m, n)); + + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + matrix_C_pointers.reset(cutlass::Coord<1>(kPartitionsK)); + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == 
cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + cutlass::reference::host::TensorFill(matrix_C_computed[k].host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + matrix_A.sync_device(); + matrix_B.sync_device(); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_pointers.at(cutlass::Coord<1>(k)) = matrix_C_computed[k].device_data(); + + matrix_C_pointers.sync_device(); + + cudaError_t result; + + int smem_size = int(sizeof(typename Mma::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + EXPECT_EQ(result, cudaSuccess) + << " cudaFuncSetAttribute " + "cudaFuncAttributeMaxDynamicSharedMemorySize error: " + << cudaGetErrorString(result); + + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + EXPECT_EQ(result, cudaSuccess) + << " cudaFuncSetAttribute " + "cudaFuncAttributePreferredSharedMemoryCarveout error: " + << cudaGetErrorString(result); + } + + test::gemm::threadblock::kernel_multistage_mma<<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_pointers.device_data(), + matrix_C_computed[0].layout().stride(0)); + + // + // Check error code + // + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, 
cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].sync_host(); + + // TODO: this is temporary. it will be removed after slicing can de + // reduction + // + // Reduce matrix_C_computed + // + CUTLASS_PRAGMA_UNROLL + for(int k = 1; k < kPartitionsK; k++) { + CUTLASS_PRAGMA_UNROLL + for(int m = 0; m < matrix_C_computed[0].extent().row(); m++){ + CUTLASS_PRAGMA_UNROLL + for(int n = 0; n < matrix_C_computed[0].extent().column(); n++){ + matrix_C_computed[0].at({m, n}) += matrix_C_computed[k].at({m, n}); + } + } + } + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + problem_size, ElementC(alpha), matrix_A.host_view(), + matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed[0].host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("mma_multistage_testbed_errors.txt"); + + output + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed[0].host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..7eb62f9a39fe4472f77446efc591267001758c58 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h @@ -0,0 
+1,353 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +namespace test { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_mma(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc) { + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + __shared__ typename Mma::SharedStorage shared_storage; + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position 
within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = threadIdx.y; + int lane_id = threadIdx.x; + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, lane_id); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_, + /// Number of stages + int Stages = 2> +struct Testbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + static const int kStages = Stages; + + // Define 
iterators over tiles from the A operand + static const bool use_idp4a = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value; + + static const bool transposeA = cutlass::platform::is_same< LayoutA, cutlass::layout::ColumnMajor >::value; + static const bool transposeB = cutlass::platform::is_same< LayoutB, cutlass::layout::RowMajor >::value; + + using IteratorA = typename cutlass::platform::conditional< use_idp4a, + cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA> , + + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA> + >::type; + + // Define iterators over tiles from the B operand + using IteratorB = typename cutlass::platform::conditional< use_idp4a, + cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB> , + + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB> + >::type; + + // Define MmaPipeline Single Stage + using MmaPipelineSingleStage = cutlass::gemm::threadblock::MmaSingleStage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, + typename MmaCore::MmaPolicy>; + + // Define MmaPipeline Two Stages + using MmaPipelineTwoStages = cutlass::gemm::threadblock::MmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, + typename MmaCore::MmaPolicy>; + + // Define the threadblock-scoped pipelined matrix multiply (Select between Single vs. 
Two stages) + using Mma = typename cutlass::platform::conditional<(kStages==1), MmaPipelineSingleStage, MmaPipelineTwoStages>::type; + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed; + cutlass::HostTensor matrix_C_reference; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed(int m, int n, int k, float alpha_, float beta_) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + } + + bool sufficient() { + return true; + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + test::gemm::threadblock::kernel_mma<<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), 
matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0)); + + // + // Check error code + // + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result) << " on device " << GetCudaDevice(); + + matrix_C_computed.sync_host(); + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + problem_size, ElementC(alpha), matrix_A.host_view(), + matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed) << "Failed on device " << GetCudaDevice(); + + if (!passed) { + std::ofstream output("mma_pipelined_testbed_errors.txt"); + + output + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h new file mode 100644 index 0000000000000000000000000000000000000000..36e55b2542b2258542336a052cdd14bf4b85f78d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h @@ -0,0 +1,370 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +namespace test { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_mma(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC **ptr_C, + typename Mma::LayoutC::Stride::Index ldc) { + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + __shared__ typename Mma::SharedStorage shared_storage; + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position 
within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = threadIdx.y; + int lane_id = threadIdx.x; + + int partitionsK_idx = warp_id / (Mma::WarpCount::kM * Mma::WarpCount::kN); + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C[partitionsK_idx], ldc}, lane_id); + + + int warp_idx_mn = warp_id % (Mma::WarpCount::kM * Mma::WarpCount::kN); + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_idx_mn % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_idx_mn / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_> +struct Testbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using 
ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + + // Define iterators over tiles from the A operand + static const bool use_idp4a = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value; + + static const bool transposeA = cutlass::platform::is_same< LayoutA, cutlass::layout::ColumnMajor >::value; + static const bool transposeB = cutlass::platform::is_same< LayoutB, cutlass::layout::RowMajor >::value; + + using IteratorA = typename cutlass::platform::conditional< use_idp4a, + cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA> , + + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA> + >::type; + + // Define iterators over tiles from the B operand + using IteratorB = typename cutlass::platform::conditional< use_idp4a, + cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB> , + + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB> + >::type; + + // Define the threadblock-scoped pipelined matrix multiply + using Mma = cutlass::gemm::threadblock::MmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, + typename MmaCore::MmaPolicy>; + + static int const kPartitionsK = MmaCore::MmaPolicy::kPartitionsK; + + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed[kPartitionsK]; + cutlass::HostTensor matrix_C_reference; + cutlass::HostTensor matrix_C_pointers; + + cutlass::gemm::GemmCoord problem_size; + float 
alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed(int m, int n, int k, float alpha_, float beta_) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].reset(cutlass::make_Coord(m, n)); + + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + matrix_C_pointers.reset(cutlass::Coord<1>(kPartitionsK)); + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + 
cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + cutlass::reference::host::TensorFill(matrix_C_computed[k].host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + matrix_A.sync_device(); + matrix_B.sync_device(); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_pointers.at(cutlass::Coord<1>(k)) = matrix_C_computed[k].device_data(); + + matrix_C_pointers.sync_device(); + + test::gemm::threadblock::kernel_mma<<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_pointers.device_data(), + matrix_C_computed[0].layout().stride(0)); + + // + // Check error code + // + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].sync_host(); + + // TODO: this is temporary. 
it will be removed after slicing can de + // reduction + // + // Reduce matrix_C_computed + // + CUTLASS_PRAGMA_UNROLL + for(int k = 1; k < kPartitionsK; k++) { + CUTLASS_PRAGMA_UNROLL + for(int m = 0; m < matrix_C_computed[0].extent().row(); m++){ + CUTLASS_PRAGMA_UNROLL + for(int n = 0; n < matrix_C_computed[0].extent().column(); n++){ + matrix_C_computed[0].at({m, n}) += matrix_C_computed[k].at({m, n}); + } + } + } + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + problem_size, ElementC(alpha), matrix_A.host_view(), + matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed[0].host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("mma_pipelined_testbed_errors.txt"); + + output + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed[0].host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..e5fdc07769726353b33c1a5da65dedfadb4ce1e7 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h @@ -0,0 +1,350 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor_planar_complex.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm_planar_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_mma_planar_complex( + cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::Element *ptr_A, + int64_t imaginary_stride_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::Element *ptr_B, + int64_t imaginary_stride_B, + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc, int64_t imaginary_stride_C) { + + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + __shared__ typename Mma::SharedStorage shared_storage; + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct 
iterators to A operand + typename Mma::IteratorA iterator_A_real(params_A, ptr_A, + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorA iterator_A_imag(params_A, ptr_A + imaginary_stride_A, + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + // Construct iterators to B operand + typename Mma::IteratorB iterator_B_real(params_B, ptr_B, + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + typename Mma::IteratorB iterator_B_imag(params_B, ptr_B + imaginary_stride_B, + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = threadIdx.y; + int lane_id = threadIdx.x; + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, lane_id); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum.real); + + iterator_C.store_with_pointer_offset(accum.imag, imaginary_stride_C); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename Mma_> +struct TestbedPlanarComplex { + + using Mma = Mma_; + using ThreadblockShape = typename Mma::Shape; + using IteratorA = typename Mma::IteratorA; + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using IteratorB = 
typename Mma::IteratorB; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Mma::ElementC; + using ElementAccumulator = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + using ThreadMapA = typename Mma::IteratorA::ThreadMap; + using ThreadMapB = typename Mma::IteratorB::ThreadMap; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + static int const Stages = Mma::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + Mma::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + Mma::kCacheOpB; + + // + // Data members + // + + cutlass::HostTensorPlanarComplex matrix_A; + cutlass::HostTensorPlanarComplex matrix_B; + cutlass::HostTensorPlanarComplex matrix_C_computed; + cutlass::HostTensorPlanarComplex matrix_C_reference; + + cutlass::gemm::GemmCoord problem_size; + + // + // Methods + // + + /// Allocates workspace in device memory + TestbedPlanarComplex(int m, int n, int k) + : problem_size(m, n, k) { + + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + + } else if (init_A == cutlass::Distribution::Sequential) { + + for 
(int i = 0; i < matrix_A.capacity() * 2; ++i) { + matrix_A.host_data()[i] = cutlass::half_t(float(i % 5) - 2); + } + /* + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity() * 2); + */ + } else if (init_A == cutlass::Distribution::Identity) { + //cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + + + } else if (init_B == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity() * 2); + + for (int i = 0; i < matrix_B.capacity() * 2; ++i) { + matrix_B.host_data()[i] = cutlass::half_t(float((i + 3) % 5) - 2); + } + + + } else if (init_B == cutlass::Distribution::Identity) { + + //cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + + } else { + return false; + } + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + test::gemm::threadblock::kernel_mma_planar_complex<<>>( + problem_size, + params_A, + matrix_A.device_data(), + matrix_A.imaginary_stride(), + params_B, + matrix_B.device_data(), + matrix_B.imaginary_stride(), + matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0), + matrix_C_computed.imaginary_stride() + ); + + + // + // Check error code + // + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + matrix_C_computed.sync_host(); + + 
cutlass::reference::host::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator + >( + problem_size, + cutlass::complex(ElementAccumulator(1)), + matrix_A.host_ref(), + Mma::kTransformA, + matrix_B.host_ref(), + Mma::kTransformB, + cutlass::complex(ElementAccumulator(0)), + matrix_C_reference.host_ref(), + matrix_C_reference.host_ref() + ); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), + matrix_C_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("mma_pipelined_testbed_errors.txt"); + + output + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/warp/testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/warp/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..921d1abdc40c2040104815cfffb8b2ea32384136 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/gemm/warp/testbed.h @@ -0,0 +1,1543 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_types.h" +#include "cutlass/subbyte_reference.h" +#include "cutlass/platform/platform.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" + +namespace test { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test kernel +template +__global__ void kernel( + typename Mma::ElementC *output_C, + typename Mma::ElementA const *input_A, + typename Mma::ElementB const *input_B, + typename Mma::ElementC const *input_C, + int iterations = 1) { + + // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. 
+ __shared__ cutlass::AlignedBuffer< + typename Mma::ElementA, ThreadblockShape::kM * ThreadblockShape::kK> smem_buffer_A; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; + + if (threadIdx.x == 0) { + typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_A.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_A, i) = + cutlass::ReferenceFactory::type>::get(input_A, i); + } + + typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_B.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_B, i) = + cutlass::ReferenceFactory::type>::get(input_B, i); + } + } + + __syncthreads(); + + // + // Construct warp-level matrix product + // + + using FragmentA = typename Mma::FragmentA; + using FragmentB = typename Mma::FragmentB; + using FragmentC = typename Mma::FragmentC; + + typename Mma::LayoutA layout_A = Mma::LayoutA::packed({ThreadblockShape::kM, ThreadblockShape::kK}); + typename Mma::LayoutB layout_B = Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); + typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); + + typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); + + typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); + + FragmentA frag_A; + FragmentB frag_B; + + FragmentC accum; + + Mma mma; + + accum.clear(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < ThreadblockShape::kK; + k += Mma::Policy::MmaShape::kK) { + iter_A.load(frag_A); + iter_B.load(frag_B); + + ++iter_A; + ++iter_B; + + mma(accum, frag_A, frag_B, accum); + } + } + + typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); + + 
iter_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_, + /// The inner product operation performed by GEMM + typename Operator_ = cutlass::arch::OpMultiplyAdd +> +struct Testbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + using Operator = Operator_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. 
+ bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + + cutlass::reference::host::BlockFillRandomUniform(tensor_A.host_data(), + tensor_A.capacity(), seed, scope_max, scope_min, 0); + + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + + 
cutlass::reference::host::BlockFillRandomUniform(tensor_B.host_data(), + tensor_B.capacity(), seed, scope_max, scope_min, 0); + + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride()[0], + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride()[0], + tensor_B.extent()); + + std::cout 
<<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride()[0] + << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride()[0] + << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_ +> +struct TestbedComplex { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + TestbedComplex() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); 
+ tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_A.host_view(), + seed, 8, -8, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_B.host_view(), + seed + 16, 8, -8, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) 
{ + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::GemmComplex( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + Mma::kTransformA, + tensor_B.host_ref(), + Mma::kTransformB, + ElementC(0), + tensor_C.host_ref(), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride()[0], + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride()[0], + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout 
+ << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test kernel +template +__global__ void kernel_transform( + typename Mma::ElementC *output_C, + typename Mma::ElementA const *input_A, + typename Mma::ElementB const *input_B, + typename Mma::ElementC const *input_C, + int iterations = 1) { + + // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementA, ThreadblockShape::kM * ThreadblockShape::kK> smem_buffer_A; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; + + if (threadIdx.x == 0) { + typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_A.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_A, i) = + cutlass::ReferenceFactory::type>::get(input_A, i); + } + + typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_B.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_B, i) = + cutlass::ReferenceFactory::type>::get(input_B, i); + } + } + + __syncthreads(); + + // + // Construct warp-level matrix product + // + + using FragmentA = typename Mma::FragmentA; + using FragmentB = typename Mma::FragmentB; + using FragmentC = typename Mma::FragmentC; + + using TransformedFragmentA = typename Mma::TransformedFragmentA; + using TransformedFragmentB = typename Mma::TransformedFragmentB; + + typename Mma::LayoutA layout_A = 
Mma::LayoutA::packed({ThreadblockShape::kM, ThreadblockShape::kK}); + typename Mma::LayoutB layout_B = Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); + typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); + + typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); + + typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); + + FragmentA loaded_frag_A; + FragmentB loaded_frag_B; + TransformedFragmentA transformed_frag_A; + TransformedFragmentB transformed_frag_B; + + FragmentC accum; + + Mma mma; + + accum.clear(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < ThreadblockShape::kK; + k += Mma::Policy::MmaShape::kK) { + iter_A.load(loaded_frag_A); + iter_B.load(loaded_frag_B); + + ++iter_A; + ++iter_B; + + mma.transform(transformed_frag_A, transformed_frag_B, loaded_frag_A, + loaded_frag_B); + + mma(accum, transformed_frag_A, transformed_frag_B, accum); + } + } + + typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); + + iter_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_, + /// The innter product operation performed by GEMM + typename Operator_ = cutlass::arch::OpMultiplyAdd +> +struct TransformTestbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + using Operator = Operator_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using 
LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + TransformTestbed() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if 
(cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel_transform<<>>( + tensor_D_computed.device_data(), tensor_A.device_data(), + tensor_B.device_data(), tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + 
tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride()[0], + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride()[0], + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_ +> +struct TransformedTestbedComplex { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + + using Shape = typename Mma::Shape; + 
using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + TransformedTestbedComplex() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. 
+ bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_A.host_view(), + seed, 8, -8, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_B.host_view(), + seed + 16, 8, -8, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + 
cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel_transform<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::GemmComplex( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + Mma::kTransformA, + tensor_B.host_ref(), + Mma::kTransformB, + ElementC(0), + tensor_C.host_ref(), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride()[0], + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride()[0], + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; + + 
std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test kernel +template +__global__ void sparse_kernel( + typename Mma::ElementC *output_C, + typename Mma::ElementA const *input_A, + typename Mma::ElementB const *input_B, + typename Mma::ElementC const *input_C, + typename Mma::ElementE const *input_E, + int iterations = 1) { + + // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. + __shared__ cutlass::AlignedBuffer + smem_buffer_A; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementE, Mma::Shape::kM * Mma::Shape::kK / + Mma::kSparse / Mma::kElementsPerElementE> + smem_buffer_E; + + __syncthreads(); + + if (threadIdx.x == 0) { + typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_A.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_A, i) = + cutlass::ReferenceFactory::type>::get(input_A, i); + } + + typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_B.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_B, i) = + cutlass::ReferenceFactory::type>::get(input_B, i); + } + + typename Mma::ElementE *smem_ptr_E = smem_buffer_E.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_E.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_E, i) = + cutlass::ReferenceFactory::type>::get(input_E, i); + } + } + + __syncthreads(); + + // + // Construct warp-level matrix product + // + + using FragmentA = typename Mma::FragmentA; + using FragmentB = typename Mma::FragmentB; + using FragmentC 
= typename Mma::FragmentC; + using FragmentE = typename Mma::FragmentE; + + typename Mma::LayoutA layout_A = Mma::LayoutA::packed( + {ThreadblockShape::kM, ThreadblockShape::kK / Mma::kSparse}); + typename Mma::LayoutB layout_B = + Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); + typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); + typename Mma::LayoutE layout_E = + Mma::LayoutE::packed({Mma::Shape::kM * Mma::kInterleaved, + Mma::Shape::kK / Mma::kSparse / + Mma::kElementsPerElementE / Mma::kInterleaved}); + + typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); + + typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); + + typename Mma::IteratorE iter_E({smem_buffer_E.data(), layout_E}, cutlass::arch::LaneId()); + + FragmentA frag_A; + FragmentB frag_B; + + FragmentC accum; + + FragmentE frag_E; + + Mma mma; + + accum.clear(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < ThreadblockShape::kK; + k += Mma::Policy::MmaShape::kK) { + iter_A.load(frag_A); + iter_B.load(frag_B); + iter_E.load(frag_E); + + ++iter_A; + ++iter_B; + ++iter_E; + + mma(accum, frag_A, frag_B, accum, frag_E); + } + } + + typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); + + iter_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_, + /// The innter product operation performed by GEMM + typename Operator_ = cutlass::arch::OpMultiplyAdd +> +struct SparseTestbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using 
ThreadblockShape = ThreadblockShape_; + using Operator = Operator_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + static int const Sparse = Mma::kSparse; + static int const MetaSizeInBits = Mma::kMetaSizeInBits; + static int const MaxID2 = Mma::kMaxID2; + static int const Interleaved = Mma::kInterleaved; + + using ElementE = typename Mma::ElementE; + + static int const ElementsPerElementE = Mma::kElementsPerElementE; + + using LayoutE = cutlass::layout::RowMajor; + using ReorderedLayoutE = + cutlass::layout::ColumnMajorInterleaved; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_uncompressed; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + cutlass::HostTensor tensor_E; + cutlass::HostTensor tensor_E_reordered; + + // + // Methods + // + + /// Allocates workspace in device memory + SparseTestbed() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, + ThreadblockShape::kK / Sparse)); + tensor_A_uncompressed.reset( + cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + tensor_E.reset(cutlass::make_Coord( + Shape::kM, Shape::kK / Sparse / ElementsPerElementE)); + tensor_E_reordered.reset(cutlass::make_Coord( + Shape::kM, Shape::kK / Sparse / ElementsPerElementE)); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. 
+ bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_E = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + 
cutlass::reference::host::TensorFillRandomUniform( + tensor_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + if (init_E == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomSparseMeta( + tensor_E.host_view(), seed, MetaSizeInBits); + } else if (init_E == cutlass::Distribution::Identity) { + uint32_t content = (MaxID2 == 1) ? 0x44444444 : 0x4444; + cutlass::reference::host::TensorFill(tensor_E.host_view(), + (ElementE)(content)); + } else { + return false; + } + + cutlass::reorder_meta( + tensor_E_reordered.host_ref(), tensor_E.host_ref(), + {Shape::kM, Shape::kN, Shape::kK / Sparse / ElementsPerElementE}); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_E_reordered.sync_device(); + + // launch kernel + sparse_kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_E_reordered.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + cutlass::uncompress(tensor_A_uncompressed.host_ref(), tensor_A.host_ref(), + 
tensor_E.host_ref(), Shape::kM, Shape::kK); + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A_uncompressed.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout << "A:\n" << tensor_A.host_view() << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout << "B:\n" << tensor_B.host_view() << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout << "E:\n" << tensor_E.host_view() << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h new file mode 100644 index 0000000000000000000000000000000000000000..3311e915db892466a9a4c52c82d100c2e1319966 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h @@ -0,0 +1,43 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace nvrtc { + +extern char const *kCutlassHeaders[]; +extern char const *kCutlassHeaderNames[]; +extern size_t const kCutlassHeaderCount; +} // namespace nvrtc +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/contraction.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/contraction.hpp new file mode 100644 index 0000000000000000000000000000000000000000..55df44379c847034ed38cfab23477331ee4a537c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/contraction.hpp @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + + +namespace nvrtc { +namespace thread { +/* Type-only bundle: from the element types, tile/cluster shapes, transpose flags and mode ranks it derives the stride types, the problem-shape type, the collective mainloop and epilogue, and the final GemmUniversal kernel type. It declares no data members and no member functions other than make_stride_tuple. */ +template< + typename ElementA, typename ElementB, typename ElementC, + typename TileShape, typename ClusterShape, + bool kTransA, bool kTransB, + int RANK_M, int RANK_N, int RANK_K, int RANK_L +> +struct ContractionKernel { + +using ElementScalar = float; +using ElementAccum = float; +using EpilogueThread = cutlass::epilogue::thread::LinearCombination; +/* A is MN-major unless kTransA is set; B is K-major unless kTransB is set. */ +static constexpr cute::GMMA::Major majorA = ! kTransA ? cute::GMMA::Major::MN : cute::GMMA::Major::K; +static constexpr cute::GMMA::Major majorB = ! kTransB ? 
cute::GMMA::Major::K : cute::GMMA::Major::MN; + +/// Kernel config +typedef int64_t stride_type; +typedef int32_t extent_type; +/* Null pointers of the stride and extent types; they appear only inside the decltype expressions below, to name tuple types without needing real data. */ +static constexpr const stride_type* stride_null = nullptr; +static constexpr const extent_type* extent_null = nullptr; +/* Builds a stride tuple from t. When IsMajor holds, mode 0 is the static unit stride cute::Int<1> and every later mode i takes t[i] if i < n and init_default otherwise; when it does not hold, the tuple comes straight from cute::make_int_tuple(t, n, init_default). Rank must be greater than 1 (static_assert). */ +template +static constexpr +auto +make_stride_tuple(Indexable const& t, int n, int64_t init_default = 0) { + static_assert(Rank > 1); + if constexpr (IsMajor) { + return cute::transform(cute::make_seq{}, [&](auto i) { + if constexpr (i == 0) { + return cute::Int<1>{}; + } + else { + return i < n ? t[i] : init_default; + } + }); + } + else { + return cute::make_int_tuple(t, n, init_default); + } +} +/* The stride and shape aliases below are obtained purely through decltype on calls that take the null pointers above; nothing is evaluated at run time. */ +using StrideA = decltype(cute::make_stride( + make_stride_tuple(stride_null, 0, 0), + make_stride_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0))); + +using StrideB = decltype(cute::make_stride( + make_stride_tuple(stride_null, 0, 0), + make_stride_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0))); + +using StrideC = decltype(cute::make_stride( + cute::make_int_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0))); + +using ProblemShape = decltype(cute::make_shape( + cute::make_int_tuple(extent_null, 0, 0), + cute::make_int_tuple(extent_null, 0, 0), + cute::make_int_tuple(extent_null, 0, 0), + cute::make_int_tuple(extent_null, 0, 0))); +/* Mainloop chosen by CollectiveBuilder for Sm90 tensor ops with the KernelTmaWarpSpecialized schedule and an automatic stage count; 16 / sizeof(Element) is the number of elements in 16 bytes (presumably the operand alignment - confirm against CollectiveBuilder). */ +using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, StrideA, 16 / sizeof(ElementA), + ElementB, StrideB, 16 / sizeof(ElementB), + ElementAccum, + TileShape, ClusterShape, cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized +>::CollectiveOp; + +using EpilogueOutputOp = cutlass::epilogue::collective::DefaultEpilogue; +using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter; +using Kernel = cutlass::gemm::kernel::GemmUniversal< + 
ProblemShape, + CollectiveOp, + CollectiveEpilogue>; + +}; + +} // namespace thread +} // namespace nvrtc diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..576f55cd868cd64c8c09c055d8b9a956e40c87ae --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h @@ -0,0 +1,76 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include "cutlass/array.h" + +namespace test { +namespace nvrtc { +namespace kernel { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level matrix multiply-accumulate +template +__global__ void testbed_kernel( + typename Mma::ElementC *D, + typename Mma::ElementA const *A, + typename Mma::ElementB const *B, + typename Mma::ElementC const *C) { + /* Reinterpret the raw element pointers so that each operand is loaded or stored with a single dereference (the cast target types are not visible here; presumably cutlass::Array fragments - confirm). */ + auto ptr_D = reinterpret_cast *>(D); + auto ptr_A = reinterpret_cast const *>(A); + auto ptr_B = reinterpret_cast const *>(B); + auto ptr_C = reinterpret_cast const *>(C); + + Mma mma; + /* No thread or block index is used: every thread of the launch reads the same A, B and C and writes the same D. */ + auto a = *ptr_A; + auto b = *ptr_B; + auto c = *ptr_C; + + cutlass::Array d; + /* The operator receives d as its first argument; d is then stored through ptr_D. */ + mma(d, a, b, c); + + *ptr_D = d; +} + +} /* namespace thread */ +} /* namespace kernel */ +} /* namespace nvrtc */ +} /* namespace test */ + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/assert.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/assert.h new file mode 100644 index 0000000000000000000000000000000000000000..c7e6e94691c82b2f343959421c884c8b0b06f9b4 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/assert.h @@ -0,0 +1,30 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/stdint.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/stdint.h new file mode 100644 index 0000000000000000000000000000000000000000..5ba5432fd568af71e15b20b8cdab1571f303bcdf --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/stdint.h @@ -0,0 +1,129 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once +/* Minimal stdint.h replacement (this file is test/unit/nvrtc/stdlib/stdint.h). Exact-width types come first. int8_t is signed char rather than plain char: the signedness of plain char is implementation-defined, and signed char also matches int_least8_t and int_fast8_t further down. */ +typedef signed char int8_t; +typedef unsigned char uint8_t; +typedef short int16_t; +typedef unsigned short uint16_t; +typedef int int32_t; +typedef unsigned int uint32_t; +typedef long long int int64_t; +typedef unsigned long long int uint64_t; +/* __WORDSIZE becomes 64 only for x86-64 without ILP32; every other target falls through to 32. NOTE(review): a 64-bit non-x86 target would therefore get the int-sized intptr_t/uintptr_t below - confirm this is intended. */ +#if defined __x86_64__ && !defined __ILP32__ +# define __WORDSIZE 64 +#else +# define __WORDSIZE 32 +#endif + + +/* Small types. */ + +/* Signed. */ +typedef signed char int_least8_t; +typedef short int int_least16_t; +typedef int int_least32_t; +#if __WORDSIZE == 64 +typedef long int int_least64_t; +#else +__extension__ +typedef long long int int_least64_t; +#endif + +/* Unsigned. */ +typedef unsigned char uint_least8_t; +typedef unsigned short int uint_least16_t; +typedef unsigned int uint_least32_t; +#if __WORDSIZE == 64 +typedef unsigned long int uint_least64_t; +#else +__extension__ +typedef unsigned long long int uint_least64_t; +#endif + + +/* Fast types. */ + +/* Signed. */ +typedef signed char int_fast8_t; +#if __WORDSIZE == 64 +typedef long int int_fast16_t; +typedef long int int_fast32_t; +typedef long int int_fast64_t; +#else +typedef int int_fast16_t; +typedef int int_fast32_t; +__extension__ +typedef long long int int_fast64_t; +#endif + +/* Unsigned. 
*/ +typedef unsigned char uint_fast8_t; +#if __WORDSIZE == 64 +typedef unsigned long int uint_fast16_t; +typedef unsigned long int uint_fast32_t; +typedef unsigned long int uint_fast64_t; +#else +typedef unsigned int uint_fast16_t; +typedef unsigned int uint_fast32_t; +__extension__ +typedef unsigned long long int uint_fast64_t; +#endif + +/* Types for `void *' pointers. */ +#if __WORDSIZE == 64 +# ifndef __intptr_t_defined +typedef long int intptr_t; +# define __intptr_t_defined +# endif +typedef unsigned long int uintptr_t; +#else +# ifndef __intptr_t_defined +typedef int intptr_t; +# define __intptr_t_defined +# endif +typedef unsigned int uintptr_t; +#endif + + +/* Largest integral types. */ +#if __WORDSIZE == 64 +typedef long int intmax_t; +typedef unsigned long int uintmax_t; +#else +__extension__ +typedef long long int intmax_t; +__extension__ +typedef unsigned long long int uintmax_t; +#endif + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/thread/testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/thread/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..8fd6863e8fa003d3fbc4e0b498e3b9b454ade190 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/thread/testbed.h @@ -0,0 +1,398 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/gemm/thread/mma.h" +#include "../kernel/thread/testbed_kernel.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/trace.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include +#include +#include "../cutlass/nvrtc/environment.h" +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace nvrtc { +namespace thread { + +#define NVRTC_RETURN_IF_ERROR(api) \ + do { \ + nvrtcResult _result = api; \ + if (_result != NVRTC_SUCCESS) { \ + CUTLASS_TRACE_HOST("Nvrtc error: " << _result); \ + return false; \ + } \ + } while(0) + +inline const char * cuda_source_fmt = R"""( + +#include "kernel/thread/contraction.hpp" + +using Operator = %s; + +extern "C" __global__ void global_entry(__grid_constant__ Operator::Params const params) { + extern __shared__ char smem[]; + + Operator op; + op(params, smem); +} + +)"""; + +struct TestbedKernel { + static bool compile(std::string const &kernel, std::vector const &opts) { + int sz = std::snprintf(nullptr, 0, cuda_source_fmt, kernel.c_str()); + std::vector cuda_source(sz + 1); + std::snprintf(&cuda_source[0], cuda_source.size(), cuda_source_fmt, kernel.c_str()); + + nvrtcProgram program; + NVRTC_RETURN_IF_ERROR( + nvrtcCreateProgram( + &program, + cuda_source.data(), + nullptr, + static_cast(cutlass::nvrtc::kCutlassHeaderCount), + cutlass::nvrtc::kCutlassHeaders, + cutlass::nvrtc::kCutlassHeaderNames) + ); + + nvrtcResult compile_result = + nvrtcCompileProgram( + program, + static_cast(opts.size()), + opts.data()); + + size_t log_size; + NVRTC_RETURN_IF_ERROR( + nvrtcGetProgramLogSize(program, &log_size) + ); + + if 
(log_size > 1) { + auto log = std::make_unique(log_size); + + NVRTC_RETURN_IF_ERROR( + nvrtcGetProgramLog(program, log.get()) + ); + + std::cout << log.get() << std::endl; + } + + NVRTC_RETURN_IF_ERROR(compile_result); + + NVRTC_RETURN_IF_ERROR( + nvrtcDestroyProgram(&program) + ); + + return true; + } +}; + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC +> +struct Testbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = cutlass::gemm::thread::Mma< + Shape, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC + >; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed() { + + tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK)); + tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + static inline bool check_nvrtc_error(nvrtcResult error) { + if (error != NVRTC_SUCCESS) { + std::cerr << "failed to compile "; + return false; + } + return true; + } + + /// Runs the test + bool run(std::string const &gemm_traits) { + + // + // initialize device memory + // + + cutlass::reference::host::BlockFillSequential( + tensor_A.host_data(), + 
tensor_A.capacity() + ); + + cutlass::reference::host::BlockFillSequential( + tensor_B.host_data(), + tensor_B.capacity(), + ElementB(1), + ElementB(2) + ); + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + +#if 0 + // launch kernel + cutlass::gemm::kernel::testbed_kernel<<< dim3(1, 1), dim3(1, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + +#else + // Instantiate gemm_kernel + nvrtcResult result_nvrtc; + nvrtcProgram program; + static char const *src = + "#include \"cutlass/gemm/thread/mma.h\"\n" + "#include \"cutlass/gemm/gemm.h\"\n" + "#include \"cutlass/layout/matrix.h\"\n" + "#include \"unit/nvrtc/kernel/thread/testbed_kernel.h\"\n" + ; + + std::string type_name; +#if 0 + // TODO Ideally we'd use nvrtcGetTypeName to determine the type, but it cannot resolve enum symbol names + // As altername solution we might want to implement to_string() to get the traits string. 
+ nvrtcGetTypeName(&type_name); +#else + type_name = gemm_traits; +#endif + + result_nvrtc = nvrtcCreateProgram(&program, + src, + NULL, + (int)cutlass::nvrtc::kCutlassHeaderCount, + cutlass::nvrtc::kCutlassHeaders, + cutlass::nvrtc::kCutlassHeaderNames); + check_nvrtc_error(result_nvrtc); + + std::string gemm_kernel_instantiation = + "test::nvrtc::kernel::thread::testbed_kernel< " + type_name + " >"; + nvrtcAddNameExpression(program, gemm_kernel_instantiation.c_str()); + + const char *opts[] = {"--gpu-architecture=compute_75", + "--std=c++17", + "--include-path=/usr/local/cuda-10.1/include"}; + + result_nvrtc = nvrtcCompileProgram(program, 3, opts); + if (result_nvrtc != NVRTC_SUCCESS) { + size_t logSize; + nvrtcGetProgramLogSize(program, &logSize); + std::vector log(logSize); + nvrtcGetProgramLog(program, log.data()); + std::cout << "Compile log:" << std::endl << log.data() << std::endl; + } + if (!check_nvrtc_error(result_nvrtc)) { + assert(0); + } + + // The lowered name is the name of the template instantiation in the generated PTX code. 
+ char const *gemm_kernel_lowered_name; + nvrtcGetLoweredName(program, gemm_kernel_instantiation.c_str(), &gemm_kernel_lowered_name); + if (!check_nvrtc_error(result_nvrtc)) { + assert(0); + } + + // Query the size of the genereated PTX so that we can allocate storage and retrieve it afterwards + size_t ptx_size; + result_nvrtc = nvrtcGetPTXSize(program, &ptx_size); + if (!check_nvrtc_error(result_nvrtc)) { + assert(0); + } + + std::vector ptx(ptx_size); + result_nvrtc = nvrtcGetPTX(program, ptx.data()); + if (!check_nvrtc_error(result_nvrtc)) { + assert(0); + } + + // we do not need the nvrtc program anymore + //nvrtcDestroyProgram(&program); + + CUmodule module; + CUresult result_cuda; + result_cuda = cuModuleLoadDataEx(&module, ptx.data(), 0, 0, 0); + if (result_cuda != CUDA_SUCCESS) { + assert(0); + } + + CUfunction kernel; + result_cuda = cuModuleGetFunction(&kernel, module, gemm_kernel_lowered_name); + if (result_cuda != CUDA_SUCCESS) { + assert(0); + } + + void* d_a = (void*)tensor_A.device_data(); + void* d_b = (void*)tensor_B.device_data(); + void* d_c = (void*)tensor_C.device_data(); + void* d_d = (void*)tensor_D_computed.device_data(); + void* args[] = { &d_d, &d_a, &d_b, &d_c }; + + // CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra + result_cuda = cuLaunchKernel(kernel, 1, 1, 1, 1, 1, 1, 0, 0 /*cudaStreamDefault*/, args, 0); + if (result_cuda != CUDA_SUCCESS) { + assert(0); + } else { +} +#endif + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cout << "CUDA ERROR: " << cudaGetErrorString(result); + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + //tensor_D_reference.fill(tensor_C.host_view()); + + cutlass::reference::host::Gemm reference_gemm; + + 
reference_gemm( + {Shape::kM, Shape::kN, Shape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + if(!passed) std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "B:\n" << tensor_B.host_view() << "\n\n" + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + + std::cout << "passed " << passed << std::endl; + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace nvrtc +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/pipeline/testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/pipeline/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..6cc2946a2c51cfb8c1971345c81c1910bd667208 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/pipeline/testbed.h @@ -0,0 +1,145 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Common Testbed file shared by Pipeline unit tests +*/ + +#include +#include +#include +#include + +#include "cutlass/util/command_line.h" +#include "../common/cutlass_unit_test.h" + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + #define CUTLASS_UNIT_TEST_PIPELINE true +#else + #define CUTLASS_UNIT_TEST_PIPELINE false +#endif + +// Command line test options +struct Options { + // + // Data Members + // + bool help; + bool verification_enabled; + int SM_count; + int clock_MHz; + + // + // Methods + // + Options(): + help(false), + verification_enabled(true), + SM_count(116), + clock_MHz(1477) + { } + /* Reads --help, --verification-enabled, --sm-count and --clock into the members; the third argument of each get_cmd_line_argument call repeats the constructor's value (presumably the fallback when the option is absent - confirm against cutlass::CommandLine). */ + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("verification-enabled", verification_enabled, true); + cmd.get_cmd_line_argument("sm-count", SM_count, 116); + cmd.get_cmd_line_argument("clock", clock_MHz, 1477); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --verification-enabled= Enable/Disable verification\n" + << " --sm-count= Number of SMs on the chip\n" + << " --clock= Locked clock value in Mhz\n"; + + return out; + } +}; + +// +// Testbed +// + +template +struct Testbed { +private: + // Commandline options + Options options; + + void run_test(uint32_t const kNumIters) { + + // Default-construct the pipeline under test and call its run() with kNumIters + Pipeline pipeline; + + cudaError_t result = pipeline.run(kNumIters); + /* NOTE(review): result is never inspected; only CUTE_CHECK_LAST() follows. */ + CUTE_CHECK_LAST(); + } + + +public: + Testbed(Options const &options_) : options(options_) { + int device_id = 0; + cudaDeviceProp device_prop; + CUTE_CHECK_ERROR(cudaSetDevice(device_id)); + CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); + /* Device 0 is always selected; the process exits with status 1 if its major compute capability is below 1. */ + if (device_prop.major < 1) { + fprintf(stderr, "Device does not support CUDA.\n"); + exit(1); + } + } + + /// Calls run_test once per entry of kNumIters, each entry being (rand() % 1000) + 1; always returns true + bool 
verification() { + + std::array kNumIters; + + for (size_t i = 0; i < kNumIters.size(); ++i) { + kNumIters[i] = static_cast( (rand() % 1000) + 1 ); + } + + for (int n : kNumIters) { + std::cout << "Stages = " << Pipeline::Stages << " kNumIters = " << n << "\n"; + run_test(n); + } + + return true; + } +}; diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/pipeline/testbed_cluster_launch_control.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/pipeline/testbed_cluster_launch_control.h new file mode 100644 index 0000000000000000000000000000000000000000..50a68a1437956c95aa4e7912e93adc8b1481c9cc --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/pipeline/testbed_cluster_launch_control.h @@ -0,0 +1,154 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Testbed file used by cluster launch control pipeline unit test +*/ + +// + +// + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + #define CUTLASS_UNIT_TEST_PIPELINE true +#else + #define CUTLASS_UNIT_TEST_PIPELINE false +#endif + +#include +#include +#include +#include + +#include "cutlass/util/command_line.h" + +// Command line test options +struct OptionsClusterLaunch { + // + // Data Members + // + bool help = false; + bool verification_enabled = true; + int SM_count = 116; + int clock_MHz = 1477; + dim3 grid_dim = {0,0,0}; + /* NOTE(review): parse() below never assigns grid_dim, so it keeps the {0,0,0} initializer unless the caller sets it directly. */ + // + // Methods + // + + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("verification-enabled", verification_enabled, verification_enabled); + cmd.get_cmd_line_argument("sm-count", SM_count, SM_count); + cmd.get_cmd_line_argument("clock", clock_MHz, clock_MHz); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --verification-enabled= Enable/Disable verification\n" + << " --sm-count= Number of SMs on the chip\n" + << " --clock= Locked clock value in Mhz\n"; + + return out; + } +}; + +// +// Testbed +// + +template +class TestbedClusterLaunch { +private: + // Commandline options + OptionsClusterLaunch options; + + bool run_test() { + + // Default-construct the pipeline under test and call its run() with the success flag and options.grid_dim + Pipeline pipeline; + + bool success = false; + cudaError_t result = pipeline.run(success, this->options.grid_dim); + /* NOTE(review): result is never inspected; the return value is the success flag passed to run(). */ + CUTE_CHECK_LAST(); + return success; + } + + +public: + TestbedClusterLaunch(OptionsClusterLaunch const &options_) : options(options_) { + int device_id = 0; + cudaDeviceProp device_prop; + CUTE_CHECK_ERROR(cudaSetDevice(device_id)); + CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); + /* Device 0 is always selected; the process exits with status 1 if its major compute capability is below 1. */ + if (device_prop.major < 1) { + fprintf(stderr, "Device does not support CUDA.\n"); + exit(1); + } + } + + /// Runs the test once and returns its result; returns true without running when CUTLASS_ARCH_MMA_SM100_SUPPORTED is not defined + bool verification() { + +#if !defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + printf( + "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be set, but it is not. 
\n" + "This test is waived.\n" + ); + return true; +#endif + +#if 0 + bool is_success = false; + for (int i = 0; i< 10; i++){ + printf("iteration = %d\n", i); + is_success = run_test(); + if ( not is_success ) + return is_success; + } + return is_success; +#else + // Run the test with single launch + return run_test(); +#endif + } +}; diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..e44a42463ae95e4f76388d791c661de875092c93 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h @@ -0,0 +1,45 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level Reduction +*/ + +#pragma once + +#include "cutlass/reduction/thread/reduce.h" + +#include "cutlass/layout/vector.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/reduction/thread/testbed.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/reduction/thread/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..239f228831a25527106af1659383112535943df1 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/reduction/thread/testbed.h @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Unit tests for thread-level Reduction +*/ + +#pragma once + +#include "cutlass/reduction/thread/reduce.h" + +#include "cutlass/layout/vector.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +namespace test { +namespace reduction { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the reduction +template < + /// Data type of elements + typename Element, + /// Number of elements + int N +> +struct Testbed_reduce_host { + + /// Thread-level reduction operator + using Reduce = cutlass::reduction::thread::Reduce< + cutlass::plus, + cutlass::Array + >; + + // + // Data members + // + + cutlass::Array tensor_in; + cutlass::Array reduced_tensor_computed; + cutlass::Array reduced_tensor_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed_reduce_host() { + tensor_in.clear(); + reduced_tensor_computed.clear(); + reduced_tensor_reference.clear(); + } + + /// Runs the test + bool run() { + + // + // initialize memory + // + + for(int i = 0; i < N; i++) + tensor_in.at(i) = Element(i); + + + Reduce reduce; + + cutlass::Array *out_ptr = &reduced_tensor_computed; + out_ptr[0] = reduce(tensor_in); + + // + // Reference implementation + // + Element e(0); + for (int i = 0; i < N; i++) + e = e + Element(i); + + reduced_tensor_reference.at(0) = e; + + // + // Verify equivalence + // + + // compare + bool passed = reduced_tensor_reference[0] == reduced_tensor_computed[0]; + + EXPECT_TRUE(passed) + << "Expected = " << float(reduced_tensor_reference.at(0)) << "\n\n" + << "Actual = " << float(reduced_tensor_computed.at(0)) << "\n\n" + << std::endl; + + return passed; + } +}; + + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level reduction kernel +template +__global__ void kernel_reduce(Element const *array_in, Element *result) { + + /// Thread-level reduction operator + using Reduce = cutlass::reduction::thread::Reduce< + cutlass::plus, + cutlass::Array + >; + + Reduce reduce; + + auto ptr_in = reinterpret_cast const *>(array_in); + auto result_ptr = reinterpret_cast *>(result); + auto in = *ptr_in; + result_ptr[0] = reduce(in); +} + + +/// Structure to compute the reduction +template < + /// Data type of elements + typename Element, + /// Number of elements + int N +> +struct Testbed_reduce_device { + + using Layout = cutlass::layout::PackedVectorLayout; + + // + // Data members + // + + cutlass::HostTensor tensor_in; + cutlass::HostTensor reduced_tensor_computed; + cutlass::HostTensor reduced_tensor_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed_reduce_device() { + + tensor_in.reset(cutlass::make_Coord(N), true); + reduced_tensor_computed.reset(cutlass::make_Coord(1), true); + reduced_tensor_reference.reset(cutlass::make_Coord(1), true); + } + + + /// Runs the test + bool run() { + + // + // initialize memory + // + + cutlass::reference::host::TensorFill( + tensor_in.host_view(), + Element(1) + ); + + cutlass::reference::host::TensorFill( + reduced_tensor_computed.host_view(), + Element(0) + ); + + cutlass::reference::host::TensorFill( + reduced_tensor_reference.host_view(), + Element(N) + ); + + tensor_in.sync_device(); + reduced_tensor_computed.sync_device(); + reduced_tensor_reference.sync_device(); + + /// call the kernel + kernel_reduce<<< dim3(1, 1), dim3(1, 1, 1) >>> ( + tensor_in.device_data(), + reduced_tensor_computed.device_data() + ); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) 
{ + return false; + } + + // Copy back results + reduced_tensor_computed.sync_host(); + + // Verify equivalence + bool passed = cutlass::reference::host::TensorEquals( + reduced_tensor_computed.host_view(), + reduced_tensor_reference.host_view() + ); + + EXPECT_TRUE(passed) + << "Expected = " << reduced_tensor_reference.host_view() << "\n\n" + << "Actual = " << reduced_tensor_computed.host_view() << "\n\n" + << std::endl; + + return passed; + } +}; + +} // namespace thread +} // namespace reduction +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c4e7de4351076dba3a699b4cb1c8a6e01485bc20 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp @@ -0,0 +1,481 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Compress utils specific for SM90 structure sparse kernels +*/ + +#pragma once + +#include // std::fill +#include // std::array +#include +#include // std::mt19937 + +#include "cute/container/bit_field.hpp" // cute::bit_field +#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v +#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor, cute::print_tensor +#include "cutlass/arch/arch.h" // cutlass::arch::Sm90 +#include "cutlass/cutlass.h" // cutlass::Status +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/layout.hpp" // cutlass::TagToStrideA_t +#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up +#include "cutlass/kernel_hardware_info.h" // cutlass::KernelHardwareInfo +#include "cutlass/util/packed_stride.hpp" // cutlass::make_cute_packed_stride +#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes +#include "cutlass/cuda_host_adapter.hpp" // cutlass::CudaHostAdapter + +namespace cutlass +{ +namespace transform +{ +namespace kernel +{ + +using namespace cute; + +namespace detail { + + template + CUTLASS_HOST_DEVICE + static uint8_t + encode_in_chunk_idx_legacy(int in_chunk_idx){ + if (sizeof(T) == 4) { + return in_chunk_idx == 0 ? 
0b0100 : 0b1110; + } + else { + uint8_t res = 0; + if (in_chunk_idx == 0) { + res = 0b00; + } + else if (in_chunk_idx == 1) { + res = 0b01; + } + else if (in_chunk_idx == 2) { + res = 0b10; + } + else { + res = 0b11; + } + return res; + } + } + + template < + class SparseConfig, + class EngineA, + class LayoutA, + class EngineAc, + class LayoutAc + > + CUTLASS_HOST_DEVICE + static void + compress_two_chunks_legacy( + Tensor tensorA, + Tensor tensorAc, + uint8_t& meta_two_chunk, + int effective_elems) { + + using ElementA = typename EngineAc::value_type; + + static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; + static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{}; + static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{}; + static constexpr int ElementEBitsPerElementAMma = typename SparseConfig::ElementEBitsPerElementAMma{}; + static constexpr int LogicalSubChunk = ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + static constexpr int PhysicalSubChunk = ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + + /* + Legal metadata chunk in SM90 + Index Bin HEX + 0, 1 0b0100 4 + 1, 2 0b1001 9 + 2, 3 0b1110 E + 0, 2 0b1000 8 + 1, 3 0b1101 D + 0, 3 0b1100 C + 2, 1 0b0110 6 (Not used) + ----------------------------------- + TF32 + 0 0b0100 4 + 1 0b1110 E + */ + + if (effective_elems <= 0) { + return; + } + + // initialize + // 0 is the initial value for this function while 0x44 is the initial value for hardware. 
+ meta_two_chunk = 0; + + for (int chunk_idx = 0; chunk_idx < 2; ++chunk_idx) { + // If Only One Chunk within this Two Chunk + if ( effective_elems <= chunk_idx * ElemsARawPerElementAMmaRaw * LogicalSubChunk ) { + break; + } + /// init result; + int non_zero_cnt = 0; + int32_t nnz_chunk_idx[PhysicalSubChunk] = { 0 }; + ElementA Ac_chunk[PhysicalSubChunk][ElemsARawPerElementAMmaRaw] = { ElementA{0} }; + + for (int subchunk_idx = 0; subchunk_idx < LogicalSubChunk; ++subchunk_idx) { + bool is_nz = true; + ElementA subchunk_elems[ElemsARawPerElementAMmaRaw] = { ElementA{0} }; + /// Check if subchunk is non-zero + for(int elem_idx = 0; elem_idx < ElemsARawPerElementAMmaRaw; elem_idx++) { + int offset = chunk_idx * LogicalElemsAPerChunk + subchunk_idx * ElemsARawPerElementAMmaRaw + elem_idx; + subchunk_elems[elem_idx] = offset < effective_elems ? tensorA(offset) : ElementA(0); + + ElementA zero = static_cast(0); + ElementA minus_zero = static_cast(ElementA(1) << cutlass::sizeof_bits_v - 1); + if (subchunk_elems[elem_idx] != zero && subchunk_elems[elem_idx] != minus_zero) { + if (non_zero_cnt >= PhysicalSubChunk) { + #ifdef __CUDA_ARCH__ + asm volatile ("brkpt;\n" ::); + #else + throw std::runtime_error("Found extra non-zero elements in a chunk!\n"); + #endif + } + is_nz = false; + } + } + + /// There is non-zero element in the subchunk + if(!is_nz) { + nnz_chunk_idx[non_zero_cnt] = subchunk_idx; + memcpy(Ac_chunk[non_zero_cnt], subchunk_elems, sizeof(ElementA) * ElemsARawPerElementAMmaRaw); + non_zero_cnt++; + } + } + + /* + Special cases + nnz == 1 and non-tf32 and nnz_idx = 3 + */ + ElementA elementA_zeros[ElemsARawPerElementAMmaRaw] = { ElementA{0} }; + if constexpr (sizeof_bits_v < 32) { + if (non_zero_cnt == 1 && nnz_chunk_idx[0] == 3) { + memcpy(Ac_chunk[1], Ac_chunk[0], sizeof(ElementA) * ElemsARawPerElementAMmaRaw); + memcpy(Ac_chunk[0], elementA_zeros, sizeof(ElementA) * ElemsARawPerElementAMmaRaw); + nnz_chunk_idx[1] = 3; + nnz_chunk_idx[0] = 0; + } + else if 
(non_zero_cnt == 1) { + memcpy(Ac_chunk[1], elementA_zeros, sizeof(ElementA) * ElemsARawPerElementAMmaRaw); + nnz_chunk_idx[1] = 3; + } + } + + /// Setup metadata + uint8_t meta_chunk = 0; + for (int i = 0; i < PhysicalSubChunk; i++) { + meta_chunk = static_cast(meta_chunk | (encode_in_chunk_idx_legacy(nnz_chunk_idx[i]) << (i * ElementEBitsPerElementAMma))); + for(int j = 0; j < ElemsARawPerElementAMmaRaw; j++) { + tensorAc(chunk_idx * PhysicalElemsAPerChunk + i * ElemsARawPerElementAMmaRaw + j) = Ac_chunk[i][j]; + } + } + meta_two_chunk = uint8_t(meta_two_chunk | (meta_chunk << (chunk_idx * _4{}))); + } + } +} + +template< + class ProblemShape_, + class ElementA_, + class LayoutATag_, + class SparseConfig_ +> +class SM90StructuredSparseCompressorLegacy { +public: + using SparseConfig = SparseConfig_; + using ProblemShape = ProblemShape_; + + // * EltA + using ElementA = ElementA_; + using ElementAUint = cute::uint_bit_t>; + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + using ArrayElementA = cute::conditional_t>, + ElementA>; + using ElementAMma = typename SparseConfig::ElementAMma; + using ElementAMmaRaw = typename SparseConfig::ElementAMmaRaw; + using ElementASparsity = typename SparseConfig::ElementASparsity; + using ElementAMmaSparsity = typename SparseConfig::ElementAMmaSparsity; + using LayoutATag = LayoutATag_; + using LayoutA = LayoutATag; + using StrideA = cutlass::gemm::TagToStrideA_t; + + // * EltE + using ElementEMma = typename SparseConfig::ElementEMma; + using ElementEMmaRaw = typename SparseConfig::ElementEMmaRaw; + using ElementEMmaSparsity = typename SparseConfig::ElementEMmaSparsity; + + // * AtomE + using TensorEAtom = typename SparseConfig::TensorEAtom; + using TensorEAtomK = typename SparseConfig::TensorEAtomK; + using TensorEAtomM = typename SparseConfig::TensorEAtomM; + + static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{}; + static 
constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; + static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{}; + static constexpr int LogicalElemsAMmaRawPerChunk = cutlass::ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + static constexpr int PhysicalElemsAMmaRawPerChunk = cutlass::ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + + // * Alignment + static constexpr int TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{}; + static constexpr int TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{}; + static constexpr int TensorAAlignmentK = typename SparseConfig::TensorAAlignmentK{}; + static constexpr int TensorAAlignmentM = typename SparseConfig::TensorAAlignmentM{}; + + // Required by `device_kernel` + static constexpr int MaxThreadsPerBlock = 1; + static constexpr int MinBlocksPerMultiprocessor = 1; + using ArchTag = arch::Sm90; + + struct SharedStorage { + /* empty, no smem needed */ + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + struct TransformArguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementA* ptr_ACompress{nullptr}; + ElementEMmaRaw* ptr_E{nullptr}; + }; + + using TransformParams = TransformArguments; + + struct Arguments { + ProblemShape problem_shape{}; + TransformArguments transform{}; + KernelHardwareInfo hw_info{}; + }; + + struct Params { + ProblemShape problem_shape{}; + TransformParams transform{}; + KernelHardwareInfo hw_info{}; + void* workspace = nullptr; + }; + + static Params + to_underlying_arguments(Arguments & args, void* workspace) { + return Params{{args.problem_shape}, + {args.transform.ptr_A, args.transform.dA, args.transform.ptr_ACompress, args.transform.ptr_E}, + {args.hw_info}, + workspace}; + } + + static Status + can_implement(Arguments const& args) { + auto [M, N, K, L] = args.problem_shape; + if (K % LogicalElemsAPerChunk != 0) { + 
CUTLASS_TRACE_HOST("SM90 Sparse Compressor CAN NOT IMPLEMENT: GemmK not multiplier of logical chunk size\n"); + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + + static size_t + get_workspace_size(Arguments const& args) { + auto problem = args.problem_shape; + const int m = cute::size<0>(problem); + const int k = cute::size<2>(problem); + const int l = cute::size<3>(problem); + const int metadata_k = round_up(k, TensorEAlignmentK); + const int metadata_m = round_up(m, TensorEAlignmentM); + const int metadata_bytes = metadata_m * metadata_k / ElementEMmaSparsity{} * l; + return metadata_bytes; + } + + static Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + cudaError_t cuda_error; + + auto workspace_size = get_workspace_size(args); + if (workspace_size == 0) { + return Status::kSuccess; + } else if (workspace == nullptr) { + return Status::kErrorInternal; + } + + cudaPointerAttributes attri; + cuda_error = cudaPointerGetAttributes(&attri, workspace); + if (cuda_error != cudaSuccess) { + return Status::kErrorInternal; + } + + if ( attri.type == cudaMemoryTypeDevice ) { +#if defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) && CUTLASS_ENABLE_CUDA_HOST_ADAPTER + CUTLASS_ASSERT(cuda_adapter); + if (Status::kSuccess != cuda_adapter->memsetDevice(workspace, static_cast(0), workspace_size, stream)) { + return Status::kErrorInternal; + } +#else + cudaMemsetAsync(workspace, 0, workspace_size, stream); + cuda_error = cudaGetLastError(); + if (cuda_error != cudaSuccess) { + return Status::kErrorInternal; + } +#endif + } else { + memset(workspace, 0, workspace_size); + } + + return Status::kSuccess; + } + + static dim3 + get_grid_shape(Params const& params) { + return dim3(1, 1, 1); + } + + static dim3 + get_block_shape() { + return dim3(1, 1, 1); + } + + CUTE_HOST_DEVICE + void + operator()(Params params, char* smem_buf = nullptr) { + run(params, 
smem_buf); + } + + CUTE_HOST_DEVICE + static void + run(Params params, char* smem_buf = nullptr) { + do_compress_device_host(params); + } + +private: + + CUTE_HOST_DEVICE + static void + do_compress_device_host(Params params) { + auto [m, n, k, l] = params.problem_shape; + auto [ptr_A, dA, ptr_ACompress, ptr_E] = params.transform; + auto workspace = params.workspace; + + const int aligned_k = (k + TensorAAlignmentK - 1) / TensorAAlignmentK * TensorAAlignmentK; + const int aligned_m = (m + TensorAAlignmentM - 1) / TensorAAlignmentM * TensorAAlignmentM; + const int metadata_k = (k + TensorEAlignmentK - 1) / TensorEAlignmentK * TensorEAlignmentK; + const int metadata_m = (m + TensorEAlignmentM - 1) / TensorEAlignmentM * TensorEAlignmentM; + const int k_compressed = aligned_k / ElementASparsity{}; + + // Convert to CuTe tensors. But don't want to use sparse_ptr, which is making everything complicated here. + cute::Tensor tensorA = make_tensor(recast_ptr(ptr_A), make_layout(make_shape(m, k, l), dA)); + + cute::Tensor tensorAc = make_tensor(recast_ptr(ptr_ACompress), + make_shape(aligned_m, k_compressed, l), + make_cute_packed_stride(StrideA{}, cute::make_shape(aligned_m, k_compressed, l))); + + cute::Tensor tensorE_raw_compress_logical = make_tensor(recast_ptr>(workspace), + make_shape(metadata_m, make_shape(TensorEAtomK{}, metadata_k / TensorEAtomK{}), l), + make_stride(TensorEAtomK{}, make_stride(_1{}, metadata_m*TensorEAtomK{}), metadata_m*metadata_k)); + + cute::Tensor tensorE_raw_compress = recast(tensorE_raw_compress_logical); + + // The following vars are all logical. 
+ int atom_m = size<0>(TensorEAtom{}); + int atom_k = size<1>(TensorEAtom{}); + int tiled_m = metadata_m / atom_m; + int tiled_ke = metadata_k / atom_k; + // Col major when viewing atoms + int stride_tile_m = cosize(TensorEAtom{}); + int stride_tile_ke = atom_k * metadata_m; + + // Logical metadata tensor + cute::Tensor tensorE_logical = make_tensor(recast_ptr>(ptr_E), + make_layout(make_shape(append(shape<0>(TensorEAtom{}), tiled_m), + append(shape<1>(TensorEAtom{}), tiled_ke), + shape<2>(tensorE_raw_compress_logical)), + make_stride(append(stride<0>(TensorEAtom{}), stride_tile_m), + append(stride<1>(TensorEAtom{}), stride_tile_ke), + stride<2>(tensorE_raw_compress_logical)))); + // Physical metadata tensor + cute::Tensor tensorE = recast(tensorE_logical); + + // void do_init() + cute::clear(tensorAc); + cute::clear(tensorE_raw_compress); + + // void do_raw_compress() + using TileStepA = Int; + using TileStepAc = Int; + + cute::Tensor tensorATiled = logical_divide(tensorA, make_shape(_, TileStepA{}, _)); + cute::Tensor tensorAcTiled = logical_divide(tensorAc, make_shape(_, TileStepAc{}, _)); + + for (int batch_idx = 0; batch_idx < l; batch_idx++) { + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int tiler_k_idx = 0; tiler_k_idx < size<1,1>(tensorATiled); tiler_k_idx++) { + int effective_elems = cute::min(TileStepA{}, k - (tiler_k_idx * TileStepA{})); + detail::compress_two_chunks_legacy(tensorATiled(m_idx, make_coord(_, tiler_k_idx), batch_idx), + tensorAcTiled(m_idx, make_coord(_, tiler_k_idx), batch_idx), + tensorE_raw_compress(m_idx, tiler_k_idx, batch_idx), + effective_elems); + } + } + } + + // void do_reorder() + // Fast path when we don't permute. 
+ if constexpr (sizeof_bits_v <= 8) { + memcpy(tensorE.data(), tensorE_raw_compress.data(), tensorE.size()); + } + else { + cute::copy(tensorE_raw_compress, tensorE); + } + + #if 0 + print("--> TensorA\n"); + auto tensorA_eltA = cute::recast(tensorA); + cute::print_tensor(tensorA_eltA); printf("\n\n"); + + print("--> REF TensorAC\n"); + auto tensorAc_eltA = cute::recast(tensorAc); + cute::print_tensor(tensorAc_eltA); printf("\n\n"); + + print("--> REF TensorE\n"); + cute::print_tensor(tensorE); printf("\n\n"); + #endif + + } +}; + +} // namespace kernel +} // namespace transform +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f44458244e0d3c4c80ecc29a0115cd6906211559 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp @@ -0,0 +1,877 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* + * @brief Test for structured sparse gemm compressor device kernel + */ + +#pragma once + +#include // cudaGetLastError + +#include // uint64_t +#include // printf +#include // malloc +#include // std::cout +#include +#include + +#include "cute/layout.hpp" // cute::make_shape +#include "cute/util/type_traits.hpp" // cute::is_same_v +#include "cutlass/coord.h" // cutlass::make_Coord +#include "cutlass/cutlass.h" // cutlass::Status +#include "cutlass/kernel_hardware_info.hpp" // cutlass::KernelHardwareInfo +#include "cutlass/layout/matrix.h" // cutlass::layout::Affine2Layout_Factory +#include "cutlass/numeric_types.h" // cutlass::sizeof_bits, cutlass::float_ +#include "cutlass/tensor_view.h" // cutlass::TensorView +#include "cutlass/transform/device/transform_universal_adapter.hpp" // cutlass::transform::device::TransformUniversalAdapter +#include 
"cutlass/transform/kernel/sparse_gemm_compressor.hpp" // cutlass::transform::kernel::StructuredSparseCompressorUtility +#include "cutlass/util/device_memory.h" // cutlass::device_memory::allocation +#include "cutlass/util/distribution.h" // cutlass::Distribution +#include "cutlass/util/host_tensor.h" // cutlass::HostTensor +#include "cutlass/util/packed_stride.hpp" // cutlass::make_cute_packed_stride +#include "cutlass/util/reference/host/tensor_compare.h" // cutlass::reference::host::TensorEquals +#include "cutlass/util/reference/host/tensor_fill.h" // cutlass::reference::host::TensorFillRandomUniform, TensorFillIdentity, TensorFillRandomGaussian, BlockFillSequential, TensorFill +#include "cutlass/detail/collective.hpp" + +#include "sm90_sparse_gemm_compressor_legacy.hpp" // Legacy host compressor +#include "../../common/cutlass_unit_test.h" // CUTLASS UT, EXPECT_TRUE + + +#define CUDA_CHECK_FALSE(cuda_error) \ + { \ + if (cuda_error != cudaSuccess) { \ + printf("cudaError %s in %s:%d\n", cudaGetErrorString(cuda_error), __func__, __LINE__ ); \ + return false; \ + } \ + } + +#define CUDA_CHECK(cuda_error) \ + { \ + if (cuda_error != cudaSuccess) { \ + printf("cudaError %s in %s:%d\n", cudaGetErrorString(cuda_error), __func__, __LINE__ ); \ + return; \ + } \ + } + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// * Test Bed +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test +{ +namespace transform +{ +namespace device +{ + +// Helper Functions +template +bool +initialize_tensor(cutlass::TensorView view, cutlass::Distribution::Kind dist_kind, uint64_t seed) +{ + if (dist_kind == cutlass::Distribution::Uniform) { + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 1; + scope_min = -1; + } else { + scope_max = 4; 
+ scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform(view, seed, scope_max, scope_min, 0); + } + + else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } + + else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; +} + +// Testbed +template +struct TestbedSparseGemmCompressor { +public: + using Compressor = Compressor_; + using CompressorKernel = typename Compressor::TransformKernel; + + using ElementA = typename CompressorKernel::ElementA; + using LayoutATag = typename CompressorKernel::LayoutATag; + using StrideA = typename CompressorKernel::StrideA; + using ArrayElementA = + ElementA + ; + + using ElementE = typename CompressorKernel::ElementEMmaRaw; + using LayoutETag = cutlass::layout::RowMajor; // We don't care about the major here, just to allocate tensor + + using SparseConfig = typename CompressorKernel::SparseConfig; + using ProblemShapeType = typename CompressorKernel::ProblemShape; + + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShapeType, + ElementA, + LayoutATag, + SparseConfig>; + + using CompressorKernelHost = cutlass::transform::kernel::SM90StructuredSparseCompressorLegacy< + ProblemShapeType, + ElementA, + LayoutATag, + SparseConfig>; + + using CompressorHost = cutlass::transform::device::TransformUniversalAdapter; + + static constexpr auto LogicalElemsAPerChunk = 
CompressorKernel::LogicalElemsAPerChunk;
+  static constexpr auto PhysicalElemsAPerChunk = CompressorKernel::PhysicalElemsAPerChunk;
+
+  struct Data {
+    // Data Storage
+    cutlass::HostTensor tensor_A;
+    cutlass::HostTensor tensor_A_Comp;
+    cutlass::HostTensor tensor_E;
+    cutlass::HostTensor tensor_A_Comp_ref;
+    cutlass::HostTensor tensor_E_ref;
+  };
+
+  // Creates one CUDA stream plus a start/stop event pair in the constructor and
+  // destroys them in the destructor. Copy and move are deleted so the handles are
+  // never destroyed twice.
+  struct CudaRAII {
+    cudaStream_t stream;
+    cudaEvent_t start;
+    cudaEvent_t stop;
+
+    CudaRAII(){
+      CUDA_CHECK(cudaStreamCreate( &stream ));
+      CUDA_CHECK(cudaEventCreate( &start ));
+      CUDA_CHECK(cudaEventCreate( &stop ));
+    };
+
+    CudaRAII(const CudaRAII&) = delete;
+    CudaRAII& operator=(const CudaRAII&) = delete;
+    CudaRAII(CudaRAII&&) = delete;
+    CudaRAII& operator=(CudaRAII&&) = delete;
+
+    ~CudaRAII(){
+      // CUDA_CHECK returns on the first failure, so release every handle first and
+      // only then report errors; otherwise one failed destroy would leak the rest.
+      const cudaError_t stream_status = cudaStreamDestroy( stream );
+      const cudaError_t start_status = cudaEventDestroy( start );
+      const cudaError_t stop_status = cudaEventDestroy( stop );
+      CUDA_CHECK(stream_status);
+      CUDA_CHECK(start_status);
+      CUDA_CHECK(stop_status);
+    }
+  };
+
+public:
+  // Stores the fill distribution used for A, E and compressed A, and the base RNG seed.
+  // The initializer list follows the member declaration order (init_A, init_A_Comp,
+  // init_E, seed), which is the order members are initialized in regardless.
+  TestbedSparseGemmCompressor(
+      cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
+      cutlass::Distribution::Kind init_E_ = cutlass::Distribution::Uniform,
+      cutlass::Distribution::Kind init_A_Comp_ = cutlass::Distribution::Uniform,
+      uint64_t seed_ = 7)
+    : init_A(init_A_)
+    , init_A_Comp(init_A_Comp_)
+    , init_E(init_E_)
+    , seed(seed_)
+  {
+  }
+
+  // Returns true when GemmK (mode 2 of the problem shape) is a multiple of
+  // LogicalElemsAPerChunk; prints a message and returns false otherwise.
+  bool valid_test(ProblemShapeType problem_shape_MNKL)
+  {
+    const int GemmK = cute::size<2>(problem_shape_MNKL);
+
+    if ( GemmK % LogicalElemsAPerChunk != 0 ) {
+      printf("GemmK needs to be multiplier of LogicalElemsAPerChunk\n");
+      return false;
+    }
+
+    return true;
+  }
+
+  bool initialize(ProblemShapeType problem_shape_MNKL, Data& datas)
+  {
+    CUDA_CHECK_FALSE(cudaGetLastError());
+
+    // In unit of ElementARaw
+    const int GemmM = cute::size<0>(problem_shape_MNKL);
+    const int GemmN = cute::size<1>(problem_shape_MNKL);
+    const int GemmK = cute::size<2>(problem_shape_MNKL);
+    const int GemmL = cute::size<3>(problem_shape_MNKL);
+
+    // Compressor utility to get allocated data size
+    auto stride_a =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL)); + CompressorUtility compressor_utility(problem_shape_MNKL, stride_a); + + // TensorA + // In unit of ElementARaw, after alignment requirement + // M-dim: no alignment requirement + // K-dim: multiplier of chunk size + + // TensorA Compressed + // In unit of ElementARaw, after alignment requirement + // M-dim: TMA alignment + // K-dim: TMA alignment + const int GemmMAlignedAC = compressor_utility.get_tensorA_m_physical(); + const int GemmKAlignedAC = compressor_utility.get_tensorA_k_physical(); + + // TensorE + // In unit of ElementE (uint8_t), after alignment requirement + // M-dim: TensorEAtom_M alignment + // K-dim: TensorEAtom_K alignment + const int GemmMAlignedE = compressor_utility.get_metadata_m_physical(); + const int GemmKAlignedE = compressor_utility.get_metadata_k_physical(); + + auto a_coord = cutlass::make_Coord(GemmM * GemmL, GemmK); + auto e_coord = cutlass::make_Coord(GemmMAlignedE * GemmL, GemmKAlignedE); + auto a_comp_coord = cutlass::make_Coord(GemmMAlignedAC * GemmL, GemmKAlignedAC); + + typename LayoutATag::Stride stride_factor_A; + typename LayoutETag::Stride stride_factor_E; + + datas.tensor_A.resize(a_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + datas.tensor_A_Comp.resize(a_comp_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); + datas.tensor_A_Comp_ref.resize(a_comp_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A), + false); + datas.tensor_E.resize(e_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); + datas.tensor_E_ref.resize(e_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E), + false); + + EXPECT_TRUE(initialize_tensor(datas.tensor_A.host_view(), init_A, seed + 1)); + EXPECT_TRUE(initialize_tensor(datas.tensor_E.host_view(), init_E, 
seed + 2)); + EXPECT_TRUE(initialize_tensor(datas.tensor_E_ref.host_view(), init_E, seed + 3)); + EXPECT_TRUE(initialize_tensor(datas.tensor_A_Comp.host_view(), init_A_Comp, seed + 4)); + EXPECT_TRUE(initialize_tensor(datas.tensor_A_Comp_ref.host_view(), init_A_Comp, seed + 5)); + + compressor_utility.structure_sparse_zero_mask_fill(datas.tensor_A.host_data(), seed + 6); + + // Check for failed devide + CUDA_CHECK_FALSE(cudaGetLastError()); + + datas.tensor_A.sync_device(); + datas.tensor_A_Comp.sync_device(); + datas.tensor_E.sync_device(); + + // Check for failed devide + CUDA_CHECK_FALSE(cudaGetLastError()); + + return true; + } + + bool run_device(ProblemShapeType problem_shape_MNKL, Data& datas, float* time = nullptr) + { + CudaRAII cuda_raii; + + const int GemmM = cute::size<0>(problem_shape_MNKL); + const int GemmN = cute::size<1>(problem_shape_MNKL); + const int GemmK = cute::size<2>(problem_shape_MNKL); + const int GemmL = cute::size<3>(problem_shape_MNKL); + + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL)); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {GemmM, GemmN, GemmK, GemmL}, + {datas.tensor_A.device_data(), + stride_a, + datas.tensor_A_Comp.device_data(), + datas.tensor_E.device_data()}, + {hw_info} + }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status {cutlass::Status::kSuccess }; + + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + CUDA_CHECK_FALSE(cudaGetLastError()); + } + + status = compressor_op.initialize(arguments, workspace.get(), cuda_raii.stream); + if (status != cutlass::Status::kSuccess) { + CUDA_CHECK_FALSE(cudaGetLastError()); + 
} + + CUDA_CHECK_FALSE(cudaStreamSynchronize(cuda_raii.stream)); + CUDA_CHECK_FALSE(cudaEventRecord(cuda_raii.start, cuda_raii.stream)); + + status = compressor_op.run(cuda_raii.stream); + if (status != cutlass::Status::kSuccess) { + CUDA_CHECK_FALSE(cudaGetLastError()); + } + + CUDA_CHECK_FALSE(cudaEventRecord(cuda_raii.stop, cuda_raii.stream)); + CUDA_CHECK_FALSE(cudaEventSynchronize(cuda_raii.stop)); + CUDA_CHECK_FALSE(cudaStreamSynchronize(cuda_raii.stream)); + if ( time != nullptr ){ + CUDA_CHECK_FALSE(cudaEventElapsedTime(time, cuda_raii.start, cuda_raii.stop)); + } + + datas.tensor_A_Comp.sync_host(); + datas.tensor_E.sync_host(); + + #if 0 + { + printf("\n--> DEVICE OUTPUT\n"); + printf("datas.tensor_A\n"); + std::cout << datas.tensor_A.host_view() << std::endl << std::endl; + printf("datas.tensor_A_Comp\n"); + std::cout << datas.tensor_A_Comp.host_view() << std::endl << std::endl; + printf("datas.tensor_E\n"); + std::cout << datas.tensor_E.host_view() << std::endl << std::endl; + } + #endif + + return true; + } + + bool run_host_ref(ProblemShapeType problem_shape_MNKL, Data& datas) + { + const int GemmM = cute::size<0>(problem_shape_MNKL); + const int GemmN = cute::size<1>(problem_shape_MNKL); + const int GemmK = cute::size<2>(problem_shape_MNKL); + const int GemmL = cute::size<3>(problem_shape_MNKL); + + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL)); + + typename CompressorKernelHost::Arguments arguments{ + {GemmM, GemmN, GemmK, GemmL}, + {datas.tensor_A.host_data(), + stride_a, + datas.tensor_A_Comp_ref.host_data(), + datas.tensor_E_ref.host_data()}, + {}}; + + const auto can_imp = CompressorKernelHost::can_implement(arguments); + if (can_imp != cutlass::Status::kSuccess) { + printf("can_implement() check failed\n"); + return false; + } + + // Relies on std::vector for RAII + auto workspace_size = + static_cast::size_type>(CompressorKernelHost::get_workspace_size(arguments)); + std::vector 
workspace_vector(workspace_size); + auto workspace = static_cast(workspace_vector.data()); + + cutlass::Status status = CompressorKernelHost::initialize_workspace(arguments, workspace); + if (status != cutlass::Status::kSuccess) { + printf("initialize_workspace() failed\n"); + return false; + } + + auto params = CompressorKernelHost::to_underlying_arguments(arguments, workspace); + CompressorKernelHost::run(params); + + return true; + } + + bool compare_reference(Data& datas) + { + bool check_tensor_a_compressed = + cutlass::reference::host::TensorEquals(datas.tensor_A_Comp_ref.host_view(), datas.tensor_A_Comp.host_view()); + if (!check_tensor_a_compressed) { + printf("A-Compressed Mismatch\n"); + } + + bool check_tensor_e = cutlass::reference::host::TensorEquals(datas.tensor_E_ref.host_view(), datas.tensor_E.host_view()); + if (!check_tensor_e) { + printf("E Mismatch\n"); + } + + return check_tensor_a_compressed && check_tensor_e; + } + + bool run_auto_small() + { + return run_auto(true); + } + + bool run_auto(bool run_small = false) + { + constexpr auto TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{}; + constexpr auto TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{}; + constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; + + constexpr int GemmN = 1; + + using ProblemType = typename std::array; + + std::vector problems; + + const std::vector problems_multiplier_of_tensor_e_atom = { + // * Regular Cases (multiplier of TensorEAlignment) + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 1}, + + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 1}, + + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 
2, 1}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 1}, + + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 2}, + + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 2}, + + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 2}, + + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 3}, + + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 3}, + + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 3}, + }; + + const std::vector problems_multiplier_of_tensor_e_atom_large = { + // * Large Case (multiplier of TensorEAlignment) + {TensorEAlignmentM * 10, GemmN, TensorEAlignmentK * 13, 1}, + // {TensorEAlignmentM * 11, GemmN, TensorEAlignmentK * 14, 2}, + // {TensorEAlignmentM * 12, GemmN, TensorEAlignmentK * 15, 3}, + }; + + const std::vector problems_multiplier_of_twochunk { + // * Corner Cases + {4, GemmN, LogicalElemsAPerChunk * 2, 1}, + {4, GemmN, LogicalElemsAPerChunk * 4, 1}, + {4, GemmN, LogicalElemsAPerChunk * 6, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, + {4, GemmN, TensorEAlignmentK * 2 + 
LogicalElemsAPerChunk * 4, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, + + {4, GemmN, LogicalElemsAPerChunk * 2, 2}, + {4, GemmN, LogicalElemsAPerChunk * 4, 2}, + {4, GemmN, LogicalElemsAPerChunk * 6, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, + + {4, GemmN, LogicalElemsAPerChunk * 2, 3}, + {4, GemmN, LogicalElemsAPerChunk * 4, 3}, + {4, GemmN, LogicalElemsAPerChunk * 6, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 2, 1}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 4, 1}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 6, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 2, 2}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 4, 2}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 6, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, + {32 + 4, GemmN, TensorEAlignmentK + 
LogicalElemsAPerChunk * 4, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 2, 3}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 4, 3}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 6, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM + 4, GemmN, 
TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 2}, + {TensorEAlignmentM * 
2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, + }; + + const std::vector problems_multiplier_of_onechunk { + {4, GemmN, LogicalElemsAPerChunk * 1, 1}, + {4, GemmN, LogicalElemsAPerChunk * 3, 1}, + {4, GemmN, LogicalElemsAPerChunk * 5, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, + + {4, GemmN, LogicalElemsAPerChunk * 1, 2}, + {4, GemmN, LogicalElemsAPerChunk * 3, 2}, + {4, GemmN, LogicalElemsAPerChunk * 5, 2}, + {4, GemmN, 
TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, + + {4, GemmN, LogicalElemsAPerChunk * 1, 3}, + {4, GemmN, LogicalElemsAPerChunk * 3, 3}, + {4, GemmN, LogicalElemsAPerChunk * 5, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 1, 1}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 3, 1}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 5, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 1, 2}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 3, 2}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 5, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, + {32 + 
4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 1, 3}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 3, 3}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 5, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM + 4, 
GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, + 
{TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, + }; + + // Run small only run multiplier of chunk size cases + if (run_small) { + problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom.begin(), problems_multiplier_of_tensor_e_atom.end()); + } + // Run full run all corner cases + else { + problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom_large.begin(), problems_multiplier_of_tensor_e_atom_large.end()); + problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom.begin(), problems_multiplier_of_tensor_e_atom.end()); + problems.insert(problems.end(), problems_multiplier_of_twochunk.begin(), problems_multiplier_of_twochunk.end()); + problems.insert(problems.end(), problems_multiplier_of_onechunk.begin(), problems_multiplier_of_onechunk.end()); + } + + for (const auto& problem_shape_MNKL : problems) { + const auto [GemmM, GemmN, GemmK, GemmL] = problem_shape_MNKL; + bool passed = run({GemmM, GemmN, GemmK, GemmL}); + printf("run() (%.4d,%.4d,%.4d,%.4d) %s\n", GemmM, GemmN, 
GemmK, GemmL, passed ? "PASS" : "FAIL"); + CUTLASS_TRACE_HOST("run() " << GemmM << " " << GemmN << " " << GemmK << " " << GemmL << passed ? " PASS" : " FAIL"); + if (not passed) { + return false; + } + } + + return true; + } + + bool run(ProblemShapeType problem_shape_MNKL) + { + // Check if valid test + if (not valid_test(problem_shape_MNKL)) { + CUTLASS_TRACE_HOST("valid_test() fail\n"); + return false; + } + + // Data Storage + Data datas; + + // Initialize Data + if (not initialize(problem_shape_MNKL, datas)) { + CUTLASS_TRACE_HOST("initialize() fail\n"); + return false; + } + + // Run Compressor (Host Ref) + if (not run_host_ref(problem_shape_MNKL, datas)) { + CUTLASS_TRACE_HOST("run_host() fail\n"); + return false; + } + + // Run Compressor (Device) + if (not run_device(problem_shape_MNKL, datas)) { + CUTLASS_TRACE_HOST("run_device() fail\n"); + return false; + } + + // Verify + if (not compare_reference(datas)) { + CUTLASS_TRACE_HOST("compare_reference() DEVICE <-> LEGACY HOST fail\n"); + printf("compare_reference() DEVICE <-> LEGACY HOST fail\n"); + return false; + } + // else { + // printf("DEVICE <-> HOST PASS\n"); + // } + + return true; + } + + bool benchmark(ProblemShapeType problem_shape_MNKL) { + const auto [GemmM, GemmN, GemmK, GemmL] = problem_shape_MNKL; + printf("Benchmark() (%.4d,%.4d,%.4d,%.4d) START\n", GemmM, GemmN, GemmK, GemmL); + + // Check if valid test + if (valid_test(problem_shape_MNKL) == false) { + CUTLASS_TRACE_HOST("valid_test() fail\n"); + return false; + } + + // 2 warm-up iterations and 10 timing iterations + constexpr int num_warmup = 5; + constexpr int num_iter = 10; + + // Duplicate data to mimic cold cache + Data data[num_warmup + num_iter]; + double total_time_milliseconds{0.0}; + + for (int i = 0; i < num_warmup + num_iter; ++i ) { + printf("Benchmark() (%.4d,%.4d,%.4d,%.4d) ITER %d\n", GemmM, GemmN, GemmK, GemmL, i ); + + auto& datum_i = data[i]; + + // Initialize Data + if (initialize(problem_shape_MNKL, datum_i) == 
false) { + CUTLASS_TRACE_HOST("initialize() fail\n"); + return false; + } + + // Run Compressor (Device) + double time_i_milliseconds{0.0f}; + if (not run_device(problem_shape_MNKL, datum_i, &time_i_milliseconds)) { + CUTLASS_TRACE_HOST("run_device() fail\n"); + return false; + } + + if ( i >= num_warmup ) { + total_time_milliseconds += time_i_milliseconds; + } + } + + const double mean_time_milliseconds = total_time_milliseconds / num_iter; + printf("Mean time (ms): %.5f\n", mean_time_milliseconds); + + return true; + } + +public: + // Data Init Setting + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_A_Comp; + cutlass::Distribution::Kind init_E; + uint64_t seed; +}; + +} // namespace device +} // namespace transform +} // namespace test diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/arch_mappings.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/arch_mappings.h new file mode 100644 index 0000000000000000000000000000000000000000..df241e3ca6e6e584af7351402d990a8028e2abed --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/arch_mappings.h @@ -0,0 +1,156 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. 
+ + Generally, + + description - compile-time constant parameters used to instantiate an operation + + configuration - runtime parameters with computationally expensive initialization + + arguments - runtime parameters that may be passed to an initialized operation with low + computational overhead +*/ + +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/arch/arch.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct ArchMap; + +template <> struct ArchMap { + static int const kMin = 50; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 60; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 61; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 70; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 70; + static int const kMax = 75; +}; + +template struct ArchMap { + static int const kMin = 75; + static int const kMax = 1024; +}; + +template struct ArchMap { + static int const kMin = 80; + static int const kMax = 1024; +}; + +template struct ArchMap { + static int const kMin = 86; + static int const kMax = 1024; +}; + +template struct ArchMap { + static int const kMin = 89; + static int const kMax = 100; +}; + +template struct ArchMap { + static int const kMin = 90; + static int const kMax = 1024; +}; + +// Arch conditional WGMMA +template <> struct ArchMap { + static int const kMin = 90; + static int const kMax = 90; +}; + +// Arch conditional sparse WGMMA +template <> struct ArchMap { + static int const kMin = 90; + static int const kMax = 90; +}; + + +template struct ArchMap { + static int const kMin = 100; + static int const kMax = 1024; +}; + +template <> struct 
ArchMap { + static int const kMin = 100; + #if (__CUDACC_VER_MAJOR__ >= 13) + static int const kMax = 110; + #else + static int const kMax = 103; + #endif // __CUDACC_VER_MAJOR__ >= 13 +}; + +template struct ArchMap { + static int const kMin = 103; + static int const kMax = 1024; +}; +template <> struct ArchMap { + static int const kMin = 103; + static int const kMax = 103; +}; + +template struct ArchMap { + static int const kMin = 120; + static int const kMax = 121; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/descriptions.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/descriptions.h new file mode 100644 index 0000000000000000000000000000000000000000..5e80c124e59d24cd90c7c1b0c06bcc3bedfee62f --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/descriptions.h @@ -0,0 +1,815 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. 
+ * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct MathInstructionDescription { + + /// Shape of the target math instruction + cutlass::gemm::GemmCoord instruction_shape; + + /// Describes the data type of the internal accumulator + NumericTypeID element_accumulator; + + /// Classification of math instruction + OpcodeClassID opcode_class; + + /// Type of math operation performed + MathOperationID math_operation; + + // + // Methods + // + + MathInstructionDescription( + cutlass::gemm::GemmCoord instruction_shape = cutlass::gemm::GemmCoord(), + NumericTypeID element_accumulator = NumericTypeID::kInvalid, + OpcodeClassID opcode_class = OpcodeClassID::kInvalid, + MathOperationID math_operation = MathOperationID::kMultiplyAdd + ): + instruction_shape(instruction_shape), + element_accumulator(element_accumulator), + opcode_class(opcode_class), + math_operation(math_operation) {} + + // Equality operator + inline + bool operator==(MathInstructionDescription const& rhs) const{ + return ( + (instruction_shape == rhs.instruction_shape) && + (element_accumulator == rhs.element_accumulator) && + (opcode_class == rhs.opcode_class) && + (math_operation == rhs.math_operation)); + } + + // Inequality operator + inline + bool operator!=(MathInstructionDescription const& rhs) const { + return !(*this == rhs); + } + +}; + +/// Structure describing the tiled structure of a GEMM-like computation +struct TileDescription { + + /// Describes the shape of a threadblock (in elements) + cutlass::gemm::GemmCoord threadblock_shape; + + /// Describes the number of pipeline stages in the threadblock-scoped mainloop + int 
threadblock_stages; + + /// Number of warps in each logical dimension + cutlass::gemm::GemmCoord warp_count; + + /// Core math instruction + MathInstructionDescription math_instruction; + + /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. + int minimum_compute_capability; + + /// Maximum compute capability (e.g. 70, 75) of a device eligible to run the operation. + int maximum_compute_capability; + + /// Describes the shape of a cluster (in blocks) + cutlass::gemm::GemmCoord cluster_shape; + + // + // Methods + // + + TileDescription( + cutlass::gemm::GemmCoord threadblock_shape = cutlass::gemm::GemmCoord(), + int threadblock_stages = 0, + cutlass::gemm::GemmCoord warp_count = cutlass::gemm::GemmCoord(), + MathInstructionDescription math_instruction = MathInstructionDescription(), + int minimum_compute_capability = 0, + int maximum_compute_capability = 0, + cutlass::gemm::GemmCoord cluster_shape = cutlass::gemm::GemmCoord(1,1,1) + ): + threadblock_shape(threadblock_shape), + threadblock_stages(threadblock_stages), + warp_count(warp_count), + math_instruction(math_instruction), + minimum_compute_capability(minimum_compute_capability), + maximum_compute_capability(maximum_compute_capability), + cluster_shape(cluster_shape) { } + + // Equality operator (NOTE(review): cluster_shape is not part of the comparison - confirm this is intentional) + inline + bool operator==(TileDescription const& rhs) const{ + return ( + (threadblock_shape == rhs.threadblock_shape) && + (threadblock_stages == rhs.threadblock_stages) && + (warp_count == rhs.warp_count) && + (math_instruction == rhs.math_instruction) && + (minimum_compute_capability == rhs.minimum_compute_capability) && + (maximum_compute_capability == rhs.maximum_compute_capability)); + } + + // Inequality operator + inline + bool operator!=(TileDescription const& rhs) const { + return !(*this == rhs); + } +}; + +/// High-level description of an operation +struct OperationDescription { + + /// Unique identifier describing the operation + char const * name; + + /// Operation
provider + Provider provider; + + /// Kind of operation + OperationKind kind; + + /// Describes the tiled structure of a GEMM-like computation + TileDescription tile_description; + + // + // Methods + // + OperationDescription( + char const * name = "unknown", + Provider provider = Provider::kInvalid, + OperationKind kind = OperationKind::kInvalid, + TileDescription const& tile_description = TileDescription() + ): + name(name), provider(provider), kind(kind), tile_description(tile_description) { } +}; + +/// Structure describing the properties of a tensor +struct TensorDescription { + + /// Numeric type of an individual element + NumericTypeID element; + + /// Enumerant identifying the layout function for the tensor + LayoutTypeID layout; + + /// Alignment restriction on pointers, strides, and extents + int alignment; + + /// log2() of the maximum extent of each dimension + int log_extent_range; + + /// log2() of the maximum value each relevant stride may have + int log_stride_range; + + // + // Methods + // + + TensorDescription( + NumericTypeID element = NumericTypeID::kInvalid, + LayoutTypeID layout = LayoutTypeID::kInvalid, + int alignment = 1, + int log_extent_range = 24, + int log_stride_range = 24 + ): + element(element), + layout(layout), + alignment(alignment), + log_extent_range(log_extent_range), + log_stride_range(log_stride_range) { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all GEMM computations +struct GemmDescription : public OperationDescription { + + /// Indicates the kind of GEMM performed + GemmKind gemm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source matrix + TensorDescription C; + + /// Describes the destination matrix + TensorDescription D; + + /// Describes the sparse meta matrices + TensorDescription E; + + /// Describes the data type of the scalars passed to 
the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + // + // Methods + // + + GemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} + + GemmDescription( + OperationDescription op_desc, + GemmKind gemm_kind, + TensorDescription const& A, + TensorDescription const& B, + TensorDescription const& C, + TensorDescription const& D, + NumericTypeID element_epilogue, + SplitKMode split_k_mode, + ComplexTransform transform_A, + ComplexTransform transform_B + ): + OperationDescription(op_desc), + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + +struct BlockScaleDescription { + /// Describes the SFA operand + TensorDescription SFA; + + /// Describes the SFB operand + TensorDescription SFB; + + /// Describes the SFD operand + TensorDescription SFD; + + /// Describes the input ScaleFactor VectorSize + int SFMVecSize; + int SFNVecSize; + int SFKVecSize; + + /// Describes the Output ScaleFactor VectorSize + int EpilogueSFVecSize; + + /// Describes the underlying kind of scaling: + /// Tensor Core supported (BlockScaled) or 
manual scaling (Blockwise) + OperationKind kind; +}; + +struct GroupedGemmDescription : public OperationDescription { + GemmDescription gemm; + std::optional block_scales; +}; + +/// Description of all GEMM computations +struct BlockScaledGemmDescription : public OperationDescription { + + /// Indicates the kind of GEMM performed + GemmKind gemm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source matrix + TensorDescription C; + + /// Describes the destination matrix + TensorDescription D; + + /// Describes the SFA operand + TensorDescription SFA; + + /// Describes the SFB operand + TensorDescription SFB; + + /// Describes the SFD operand + TensorDescription SFD; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + /// Describes the input ScaleFactor VectorSize + int SFVecSize; + + /// Describes the Output ScaleFactor VectorSize + int EpilogueSFVecSize; + + // + // Methods + // + + BlockScaledGemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} + + BlockScaledGemmDescription( + OperationDescription 
op_desc, + GemmKind gemm_kind, + TensorDescription const& A, + TensorDescription const& B, + TensorDescription const& C, + TensorDescription const& D, + NumericTypeID element_epilogue, + SplitKMode split_k_mode, + ComplexTransform transform_A, + ComplexTransform transform_B + ): + OperationDescription(op_desc), + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + +/// Description of all GEMM computations +struct BlockwiseGemmDescription : public OperationDescription { + + /// Indicates the kind of GEMM performed + GemmKind gemm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source matrix + TensorDescription C; + + /// Describes the destination matrix + TensorDescription D; + + /// Describes the SFA operand + TensorDescription SFA; + + /// Describes the SFB operand + TensorDescription SFB; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + /// Describes the input ScaleFactor VectorSize + int SFMVecSize; + int SFNVecSize; + int SFKVecSize; + + // + // Methods + // + + BlockwiseGemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + 
gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} + + BlockwiseGemmDescription( + OperationDescription op_desc, + GemmKind gemm_kind, + TensorDescription const& A, + TensorDescription const& B, + TensorDescription const& C, + TensorDescription const& D, + NumericTypeID element_epilogue, + SplitKMode split_k_mode, + ComplexTransform transform_A, + ComplexTransform transform_B + ): + OperationDescription(op_desc), + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description for structured sparse GEMMs. +struct SparseGemmDescription : public GemmDescription { + + /// Description structure for structured sparse GEMM + SparseGemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = TensorDescription(), + TensorDescription const& E = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + GemmDescription(gemm_kind, A, B, C, D, element_epilogue, split_k_mode, transform_A, transform_B) + {this->E = E;} +}; + +/// Description of all Reduction operations +struct ReductionDescription : public OperationDescription { + + /// Describes the data type of workspace + NumericTypeID element_workspace; + + /// Describes the data type of final output + NumericTypeID element_output; + + /// Describes the data type of the scalars passed to the 
epilogue + NumericTypeID element_epilogue; +}; + +/// Description of all Rank K update computations (SYRK, HERK, SYR2K, HER2K) +struct RankKDescription : public OperationDescription { + + /// Indicates which device template is used (universal or regular) + RankKKind rank_k_kind; + + /// Number of rank update (rank k or rank 2k) + int num_ranks; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand (used only for SYR2K and HER2K) + TensorDescription B; + + /// Describes the source and destination matrices + TensorDescription C; + + /// Describes the fill mode for matrix C + FillMode fill_mode; + + /// Describes the blas mode (symmetric/hermitian) + BlasMode blas_mode; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + // + // Methods + // + + RankKDescription( + RankKKind rank_k_kind = RankKKind::kUniversal, + int num_ranks = 1, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + FillMode fill_mode = FillMode::kInvalid, + BlasMode blas_mode = BlasMode::kInvalid, + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + rank_k_kind(rank_k_kind), + num_ranks(num_ranks), + A(A), + B(B), + C(C), + fill_mode(fill_mode), + blas_mode(blas_mode), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all 
TRMM computations +struct TrmmDescription : public OperationDescription { + + /// Indicates the kind of TRMM performed + TrmmKind trmm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the side mode for matrix A + SideMode side_mode; + + /// Describes the fill mode for matrix A + FillMode fill_mode; + + /// Describes the diag type for matrix A + DiagType diag_type; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source and destination matrices + TensorDescription D; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + // + // Methods + // + + TrmmDescription( + TrmmKind trmm_kind = TrmmKind::kUniversal, + TensorDescription const& A = TensorDescription(), + SideMode side_mode = SideMode::kInvalid, + FillMode fill_mode = FillMode::kInvalid, + DiagType diag_type = DiagType::kInvalid, + TensorDescription const& B = TensorDescription(), + TensorDescription const& D = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone + ): + trmm_kind(trmm_kind), + A(A), + side_mode(side_mode), + fill_mode(fill_mode), + diag_type(diag_type), + B(B), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all SYMM/HEMM update computations +struct SymmDescription : public OperationDescription { + + /// Indicates which device template is used (universal or regular) + SymmKind symm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source and 
destination matrices + TensorDescription C; + + /// Describes the side mode for matrix A + SideMode side_mode; + + /// Describes the fill mode for matrix A + FillMode fill_mode; + + /// Describes the blas mode (symmetric/hermitian) + BlasMode blas_mode; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + // + // Methods + // + + SymmDescription( + SymmKind symm_kind = SymmKind::kUniversal, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + SideMode side_mode = SideMode::kInvalid, + FillMode fill_mode = FillMode::kInvalid, + BlasMode blas_mode = BlasMode::kInvalid, + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + symm_kind(symm_kind), + A(A), + B(B), + C(C), + side_mode(side_mode), + fill_mode(fill_mode), + blas_mode(blas_mode), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all Conv2d operations +struct ConvDescription : public OperationDescription { + /// Describes the convolution dimension support (2D or 3D) + int conv_dim; + + /// Describes the kind of convolution + ConvKind conv_kind; + + /// Describes the type of iterator algorithm (analytic or precomputed) + IteratorAlgorithmID iterator_algorithm; + + /// 
Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the C operand + TensorDescription C; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + // + // Methods + // + // Returns Activation TensorDescription + TensorDescription activation() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return A; + case library::ConvKind::kDgrad : return C; + case library::ConvKind::kWgrad : return B; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Filter TensorDescription + TensorDescription filter() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return B; + case library::ConvKind::kDgrad : return B; + case library::ConvKind::kWgrad : return C; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Output TensorDescription + TensorDescription output() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return C; + case library::ConvKind::kDgrad : return A; + case library::ConvKind::kWgrad : return A; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/handle.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/handle.h new file mode 100644 index 0000000000000000000000000000000000000000..027944eb6ac8c6e8f250d83ed33c0899adfbd3e8 --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/handle.h @@ -0,0 +1,365 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief BLAS-like handle used to launch operations on the CUDA device. +*/ + +#pragma once + +#include +#include "cutlass/library/library.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Handle object +class Handle { +private: + + /// Host workspace + static int const kHostWorkspaceSize = (4 << 10); + + /// Provider of operations + Provider provider_; + + /// CUDA device properties + cudaDeviceProp device_; + + /// CUDA stream + cudaStream_t stream_; + + /// Device workspace + void *workspace_; + + /// Size of device workspace in bytes + size_t workspace_size_; + + /// Indicates whether scalars are host or device pointers + ScalarPointerMode scalar_pointer_mode_; + + /// Pointer to the most recently executed operation + Operation const *last_operation_; + + int device_idx_; + +public: + + /// Constructor + Handle(cudaStream_t stream = nullptr, size_t workspace_size = (4<<20)); + + /// Destructor + ~Handle(); + + /// Move constructor + Handle(Handle && handle); + + /// Move assignment operator + Handle &operator=(Handle && handle); + + // + // Persistent state accessors + // + + /// Returns compute capability of the selected device + int compute_capability() const; + + /// Sets the current CUDA stream + void set_stream(cudaStream_t stream); + + /// Gets the current CUDA stream + cudaStream_t get_stream() const; + + /// Gets the current provider + Provider get_provider() const; + + /// Sets the provider of operations + void set_provider(Provider provider); + + /// Gets the device workspace size + size_t get_workspace_size() const; + + /// Gets a pointer to the device workspace allocation in Global Memory + void *get_workspace() const; + + /// Sets the size of device workspace, invalidating calls to get_device_workspace() + void set_workspace_size(size_t 
bytes); + + /// Gets the scalar pointer mode + ScalarPointerMode get_scalar_pointer_mode() const; + + /// Sets the scalar pointer mode + void set_scalar_pointer_mode(ScalarPointerMode mode); + + /// Gets the most recently executed operation + Operation const *get_last_operation() const; + + // + // Computations + // + + /// Executes a GEMM computation: D <= alpha * A*B + beta * C + Status gemm( + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices + + void const * ptr_A, /// Pointer to A matrix in Global Memory + int64_t lda, /// Leading dimension of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices + + void const * ptr_B, /// Pointer to B matrix in Global Memory + int64_t ldb, /// Leading dimension of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrices + + void const * ptr_C, /// Pointer to C matrix + int64_t ldc, /// Leading dimension of C matrix + + void * ptr_D, /// Pointer to D matrix + int64_t ldd /// Leading dimension of D matrix + ); + + /// Executes a GEMM computation: D <= alpha * A*B + beta * C. + // + // Supports batched-strided, batched array or split-K serial or split-K parallel. 
+ // + Status gemm_universal( + + GemmUniversalMode mode, /// indicates the mode in which the kUniversal GEMM is launched + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + int cluster_m, /// cluster shape M dimension + int cluster_n, /// cluster shape N dimension + int cluster_k, /// cluster shape K dimension + int cluster_m_fallback, /// Fallback cluster shape M dimension + int cluster_n_fallback, /// Fallback cluster shape N dimension + int cluster_k_fallback, /// Fallback cluster shape K dimension + + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices + void const * ptr_A, /// Pointer to A matrix in Global Memory + int64_t lda, /// Leading dimension of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices + void const * ptr_B, /// Pointer to B matrix in Global Memory + int64_t ldb, /// Leading dimension of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C matrix + LayoutTypeID layout_C, /// Layout of D matrix + void const * ptr_C, /// Pointer to C matrix + int64_t ldc, /// Leading dimension of C matrix + + NumericTypeID element_D, /// Data type of D matrix + LayoutTypeID layout_D, /// Layout of D matrix + void * ptr_D, /// Pointer to D matrix + int64_t ldd, /// Leading dimension of D matrix + + int batch_count = 1, /// Batch count or number of split-K slices + + int64_t batch_stride_A = 0, /// Batch stride of A operand + 
int64_t batch_stride_B = 0, /// Batch stride of B operand + int64_t batch_stride_C = 0, /// Batch stride of C operand + int64_t batch_stride_D = 0 /// Batch stride of D operand + ); + + /// Planar complex GEMM + /// + /// Note, all data types are the real-valued base types used by the planar-complex GEMM kernel. + /// + Status gemm_planar_complex( + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix + + void const * ptr_A_real, /// Pointer to real part of A matrix + void const * ptr_A_imag, /// Pointer to imaginary part of A matrix + int64_t lda_real, /// Leading dimension of real part of A matrix + int64_t lda_imag, /// Leading dimension of imaginary part of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix + + void const * ptr_B_real, /// Pointer to real part of B matrix + void const * ptr_B_imag, /// Pointer to imaginary part of B matrix + int64_t ldb_real, /// Leading dimension of real part of B matrix + int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrix + + void const * ptr_C_real, /// Pointer to real part of C matrix + void const * ptr_C_imag, /// Pointer to imaginary part of C matrix + int64_t ldc_real, /// Leading dimension of real part of C matrix + int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix + + void * ptr_D_real, /// Pointer to real part 
of D matrix + void * ptr_D_imag, /// Pointer to imaginary part of D matrix + int64_t ldd_real, /// Leading dimension of real part of D matrix + int64_t ldd_imag, /// Leading dimension of imaginary part of D matrix + + int batch_count = 1, /// Number of batched GEMMs to execute + + int64_t batch_stride_A_real = 0, + int64_t batch_stride_A_imag = 0, + + int64_t batch_stride_B_real = 0, + int64_t batch_stride_B_imag = 0, + + int64_t batch_stride_C_real = 0, + int64_t batch_stride_C_imag = 0, + + int64_t batch_stride_D_real = 0, + int64_t batch_stride_D_imag = 0 + ); + + /// Planar complex GEMM loading pointers from arrays in global memory + Status gemm_planar_complex_array( + + int expected_M, /// Expected GEMM M dimension (used for sizing CUDA grid) + int expected_N, /// Expected GEMM N dimension (used for sizing CUDA grid) + int expected_K, /// Expected GEMM K dimension + int batch_count, /// Number of independent GEMM computations to execute + + int const *M, /// Array containing the GEMM M dimension for each batch index + int const *N, /// Array containing the GEMM N dimension for each batch index + int const *K, /// Array containing the GEMM K dimension for each batch index + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix + + void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices + void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices + + int64_t lda_real, /// Leading dimension of real part of A matrix + int64_t lda_imag, /// Leading dimension of imaginary part of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements 
+ LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix + + void const * const * ptr_B_real, /// Pointer to array containing pointers to real part of B matrices + void const * const * ptr_B_imag, /// Pointer to array containing pointers to imaginary part of B matrices + + int64_t ldb_real, /// Leading dimension of real part of B matrix + int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrix + + void const * const * ptr_C_real, /// Pointer to array containing pointers to real part of C matrices + void const * const * ptr_C_imag, /// Pointer to array containing pointers to imaginary part of C matrices + + int64_t ldc_real, /// Leading dimension of real part of C matrix + int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix + + void * const * ptr_D_real, /// Pointer to array containing pointers to real part of D matrices + void * const * ptr_D_imag, /// Pointer to array containing pointers to imaginary part of D matrices + + int64_t ldd_real, /// Leading dimension of real part of D matrix + int64_t ldd_imag /// Leading dimension of imaginary part of D matrix + ); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Unique pointer storing the handle +using HandlePtr = std::unique_ptr; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Finds conv2d operation instances with Conv2d::ElementC = Reduction::ElementWorkspace +Operation const* find_conv_operation_for_parallel_reduction(Operation const *operation); +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Finds gemm operation instances with ElementC = Reduction::ElementWorkspace +Operation const* 
find_gemm_operation_for_parallel_reduction(Operation const *operation); +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/library.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/library.h new file mode 100644 index 0000000000000000000000000000000000000000..6764d9a6d81286c8bba0f5184b17819bfae86978 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/library.h @@ -0,0 +1,995 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. 
+ + Generally, + + description - compile-time constant parameters used to instantiate an operation + + configuration - runtime parameters with computationally expensive initialization + + arguments - runtime parameters that may be passed to an initialized operation with low + computational overhead +*/ + +#ifndef CUTLASS_LIBRARY_LIBRARY_H +#define CUTLASS_LIBRARY_LIBRARY_H + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/library/types.h" +#include "cutlass/library/descriptions.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/blas3.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mode of Universal GEMM +using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Base class for all operations +class Operation { +public: + + virtual ~Operation() { } + + virtual OperationDescription const & description() const = 0; + + virtual Status can_implement( + void const *configuration, + void const *arguments) const = 0; + + virtual uint64_t get_host_workspace_size( + void const *configuration) const = 0; + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const = 0; + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) 
const = 0; + + // Originally designed for metadata, but should be useful for FP8/6/4 too. + virtual Status initialize_with_profiler_workspace( + void const *configuration, + void *host_workspace, + void *device_workspace, + uint8_t **profiler_workspace_ptrs, + int problem_count, + cudaStream_t stream = nullptr) { + return Status::kErrorNotSupported; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const = 0; + + // Set arguments that should only be set once before verifying or profiling the kernel. + // This should encompass any expensive operations that don't vary from run to run + // (e.g., max_active_clusters). + virtual Status initialize_with_arguments(void* arguments_ptr) const { + return Status::kSuccess; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for basic GEMM operations +// +// OperationKind: Gemm +// GemmKind: Gemm +// +struct GemmConfiguration { + + /// GEMM problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Number of partitions of K dimension + int split_k_slices{0}; +}; + +/// Arguments for GEMM +struct GemmArguments { + + /// Pointer to A matrix + void const *A{nullptr}; + + /// Pointer to B matrix + void const *B{nullptr}; + + /// Pointer to C matrix + void const *C{nullptr}; + + /// Pointer to D matrix + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + /// Whether to use PDL when launching the kernel + bool 
use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for batched GEMM in which multiple matrix products are computed +// +// OperationKind: Gemm +// GemmKind: Batched + +struct GemmBatchedConfiguration { + + /// GEMM problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Stride between instances of the A matrix in memory + int64_t batch_stride_A{0}; + + /// Stride between instances of the B matrix in memory + int64_t batch_stride_B{0}; + + /// Stride between instances of the C matrix in memory + int64_t batch_stride_C{0}; + + /// Stride between instances of the D matrix in memory + int64_t batch_stride_D{0}; + + /// Number of GEMMs in batch + int batch_count{1}; +}; + +/// Arguments to batched GEMM +using GemmBatchedArguments = GemmArguments; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for batched GEMM in which multiple matrix products are computed +// +// OperationKind: Gemm +// GemmKind: Array + +struct GemmArrayConfiguration { + + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + int batch_count{1}; +}; + +/// Arguments for GEMM - used by all the GEMM operations +struct GemmArrayArguments { + void const * const *A{nullptr}; + void const * const *B{nullptr}; + void const * const *C{nullptr}; + void * const *D{nullptr}; + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + bool use_pdl{false}; +}; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Universal GEMM supporting multiple split-K modes, multiple batched modes, real and complex +// +// OperationKind: Gemm +// GemmKind: Universal + +struct GemmUniversalConfiguration { + + GemmUniversalMode mode{GemmUniversalMode::kGemm}; + gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + int batch_count{1}; + + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + int64_t ldd{0}; + + int device_count{1}; +}; + +enum class Sm90MixedInputWiderOperand { + A = 0, + B = 1 +}; + +struct GemmUniversalArguments { + // NOTE: these are replicated for 3.0 interfaces + gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + int batch_count{1}; + + void const *A{nullptr}; + void const *B{nullptr}; + void const *C{nullptr}; + void *D{nullptr}; + + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + + // NOTE: these are replicated for 3.0 interfaces + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + int64_t ldd{0}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + + // Needed for some 3.x kernels + int sm_count{0}; + library::RasterOrder raster_order{}; + library::RuntimeDatatype runtime_input_datatype_a{}; + library::RuntimeDatatype runtime_input_datatype_b{}; + int swizzle_size{1}; + int split_k_slices{1}; + + // For SM90 mixed input dtype kernels + bool is_sm90_mixed_dtype{false}; + Sm90MixedInputWiderOperand wider_operand{Sm90MixedInputWiderOperand::B}; + bool generate_scale_and_zero{false}; + bool generate_dequantized_AB{false}; + void *Scale{nullptr}; // Scale tensor + void *Zero{nullptr}; // Zero tensor + void *dequantized_AB{nullptr}; // Dequantized A or B tensor for verification + void *encoded_AB{nullptr}; // Encoded A or B in int4 x 
fp8 or shuffle + void *packed_Scale{nullptr}; // Packed scale for int4 * fp8 + + int device_index{0}; + + bool use_pdl{false}; +}; + +/// Block Scaled GEMM +// +// OperationKind: kBlockScaledGemm +// GemmKind: Universal + +struct BlockScaledGemmArguments { + // NOTE: these are replicated for 3.0 interfaces + gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + int batch_count{1}; + + void const *A{nullptr}; + void const *B{nullptr}; + void const *SFA{nullptr}; + void const *SFB{nullptr}; + void const *C{nullptr}; + void *D{nullptr}; + void *SFD{nullptr}; + + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + + // NOTE: these are replicated for 3.0 interfaces + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + int64_t ldd{0}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + + // Needed for ScaleFactor Generation + void const *norm_constant{nullptr}; + + // Needed for some 3.x kernels + int sm_count{0}; + library::RasterOrder raster_order{}; + int swizzle_size{1}; + int split_k_slices{1}; + + library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic}; + library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic}; + + bool use_pdl{false}; +}; + +/// Blockwise GEMM +// +// OperationKind: kBlockwiseGemm +// GemmKind: Universal + +struct BlockwiseGemmArguments { + // NOTE: these are replicated for 3.0 interfaces + gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + int batch_count{1}; + + void const *A{nullptr}; + void const *B{nullptr}; + void const *SFA{nullptr}; + void const *SFB{nullptr}; + void const *C{nullptr}; + void *D{nullptr}; + + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + + // NOTE: these are replicated for 3.0 interfaces 
+ int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + int64_t ldd{0}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + + int sf_m_vec_size{0}; + int sf_n_vec_size{0}; + int sf_k_vec_size{0}; + + // Needed for some 3.x kernels + int sm_count{0}; + library::RasterOrder raster_order{}; + int swizzle_size{1}; + int split_k_slices{1}; + + library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic}; + library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic}; + + bool use_pdl{false}; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Complex valued GEMM in which real and imaginary parts are separated by a stride +// +// OperationKind: Gemm +// GemmKind: Planar complex + +struct GemmPlanarComplexConfiguration { + + GemmUniversalMode mode{GemmUniversalMode::kGemm}; + gemm::GemmCoord problem_size{}; + int batch_count{1}; + int64_t lda_real{0}; + int64_t lda_imag{0}; + int64_t ldb_real{0}; + int64_t ldb_imag{0}; + int64_t ldc_real{0}; + int64_t ldc_imag{0}; + int64_t ldd_real{0}; + int64_t ldd_imag{0}; +}; + +/// Arguments for planar complex GEMMs +struct GemmPlanarComplexArguments { + + void const *A_real{nullptr}; + void const *A_imag{nullptr}; + void const *B_real{nullptr}; + void const *B_imag{nullptr}; + void const *C_real{nullptr}; + void const *C_imag{nullptr}; + void *D_real{nullptr}; + void *D_imag{nullptr}; + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + + int64_t batch_stride_A_real{0}; + int64_t batch_stride_A_imag{0}; + int64_t batch_stride_B_real{0}; + int64_t batch_stride_B_imag{0}; + int64_t batch_stride_C_real{0}; + int64_t batch_stride_C_imag{0}; + int64_t batch_stride_D_real{0}; + int64_t batch_stride_D_imag{0}; + bool use_pdl{false}; +}; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This is a special form of planar complex which loads pointers and problem size +/// from memory. +struct GemmPlanarComplexArrayConfiguration { + + gemm::GemmCoord problem_size{}; + int batch_count{1}; + + int64_t lda_real{0}; + int64_t lda_imag{0}; + int64_t ldb_real{0}; + int64_t ldb_imag{0}; + int64_t ldc_real{0}; + int64_t ldc_imag{0}; + int64_t ldd_real{0}; + int64_t ldd_imag{0}; +}; + +/// Arguments for planar complex GEMMs +struct GemmPlanarComplexArrayArguments { + + int const *M{nullptr}; + int const *N{nullptr}; + int const *K{nullptr}; + + void const * const * A_real{nullptr}; + void const * const * A_imag{nullptr}; + void const * const * B_real{nullptr}; + void const * const * B_imag{nullptr}; + void const * const * C_real{nullptr}; + void const * const * C_imag{nullptr}; + void * const * D_real{nullptr}; + void * const * D_imag{nullptr}; + + void const * alpha{nullptr}; + void const * beta{nullptr}; + ScalarPointerMode pointer_mode{}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Grouped GEMM supporting +// +// OperationKind: Gemm +// GemmKind: Grouped + +struct GemmGroupedConfiguration { + int problem_count{0}; + // GemmGroupedConfiguration is passed to initialize(), which + // is responsible for allocating the device-side stride storage. 
+ int64_t* lda; + int64_t* ldb; + int64_t* ldc; + + cute::Shape* problem_sizes_3x_host; +}; + +struct GemmGroupedArguments { + int problem_count{}; + gemm::GemmCoord* problem_sizes{nullptr}; + + void* ptr_A{nullptr}; + void* ptr_B{nullptr}; + void* ptr_C{nullptr}; + void* ptr_D{nullptr}; + + int64_t* lda{nullptr}; + int64_t* ldb{nullptr}; + int64_t* ldc{nullptr}; + int64_t* ldd{nullptr}; + + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + bool use_pdl{false}; + + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + + library::RasterOrder raster_order{}; + library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic}; + library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic}; + int swizzle_size{1}; + + // these should really be in the configuration but staying consistent with GEMM + int sm_count{0}; + int max_active_clusters{0}; + + // The user is responsible for allocating storage for problem sizes. + // Since GemmGroupedArguments is used by both the 2.x and 3.x APIs, we + // unfortunately need to have both options in this struct, and the + // underlying operation uses the one it needs. + cute::Shape* problem_sizes_3x; + cute::Shape* problem_sizes_3x_host; +}; + +struct GroupedGemmBlockScaledArguments : GemmGroupedArguments { + void* SFA{nullptr}; + void* SFB{nullptr}; + void* SFD{nullptr}; + void* norm_constant{nullptr}; +}; + +struct GroupedGemmBlockwiseArguments : GemmGroupedArguments { + void* SFA{nullptr}; + void* SFB{nullptr}; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// OperationKind: kSparseGemm +// + +/// Computes GEMM assuming one of the inputs has 2:4 structured sparsity. 
+struct SparseGemmConfiguration { + + GemmUniversalMode mode{GemmUniversalMode::kGemm}; + gemm::GemmCoord problem_size{}; + int batch_count{1}; /// number of sparse matrix products in batch + int64_t lda{0}; /// leading dimension of A operand + int64_t ldb{0}; /// leading dimension of B operand + int64_t ldc{0}; /// leading dimension of C operand + int64_t ldd{0}; /// leading dimension of D operand + int64_t lde{0}; /// leading dimension of E operand (metadata matrix) + int64_t batch_stride_A{0}; // stride between matrices + int64_t batch_stride_B{0}; // stride between matrices + int64_t batch_stride_C{0}; // stride between matrices + int64_t batch_stride_D{0}; // stride between matrices + int64_t batch_stride_E{0}; // stride between matrices +}; + +/// Arguments for sparse GEMMs +struct SparseGemmArguments { + void const *A{nullptr}; /// pointer to A matrix + void const *B{nullptr}; /// pointer to B matrix + void const *C{nullptr}; /// pointer to C matrix + void *D{nullptr}; /// pointer to D matrix + void const *E{nullptr}; /// pointer to E matrix (metadata) + void const *alpha{nullptr}; /// pointer to alpha scalar + void const *beta{nullptr}; /// pointer to beta scalar + ScalarPointerMode pointer_mode{}; /// enumerant indicating whether alpha/beta pointers are host + /// or device pointers. 
+ bool use_pdl{false}; /// Whether to use PDL when launching the kernel +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for basic Rank K update operations +// +// OperationKind: (Syrk, Herk, Syr2k, Her2k) +// RankKKind: Universal +// +struct RankKConfiguration { + + /// SYRK problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Batch Count + int batch_count{1}; +}; + +/// Arguments for (Syrk, Herk, Syr2k, Her2k) +struct RankKArguments { + + /// Pointer to A matrix + void const *A{nullptr}; + + /// Pointer to B matrix (used only for Syr2k and Her2k) + void const *B{nullptr}; + + /// Pointer to C matrix + void const *C{nullptr}; + + /// Pointer to D matrix + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for basic TRMM operations +// +// OperationKind: Trmm +// TrmmKind: Universal +// +struct TrmmConfiguration { + + /// TRMM problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Batch Count + int batch_count{1}; +}; + +/// Arguments for TRMM +struct TrmmArguments { + + /// Pointer to A matrix + void const 
*A{nullptr}; + + /// Pointer to B matrix + void const *B{nullptr}; + + /// Pointer to D matrix + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_D{0}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for basic SYMM/HEMM update operations +// +// OperationKind: (Symm, Hemm) +// SymmKind: Universal +// +struct SymmConfiguration { + + /// SYMM/HEMM problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Batch Count + int batch_count{1}; +}; + +/// Arguments for (Symm, Hemm) +struct SymmArguments { + + /// Pointer to A matrix + void const *A{nullptr}; + + /// Pointer to B matrix + void const *B{nullptr}; + + /// Pointer to C matrix + void const *C{nullptr}; + + /// Pointer to D matrix + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Two dimensional convolution 
+// +// OperationKind: Conv2d +// +struct Conv2dConfiguration { + + conv::SplitKMode split_k_mode; + + /// Conv2d problem size + // contains strictly conv2d size (N,H,W,C,K,R,S,P,Q,padding,stride,dilation,mode) + // also includes (split_k_slices, groups) + conv::Conv2dProblemSize problem_size{}; + + // stride of operand A + std::vector stride_a{}; + + // stride of operand B + std::vector stride_b{}; + + // stride of operand C + std::vector stride_c{}; +}; + + +/// Three dimensional convolution +// +// OperationKind: Conv3d +// +struct Conv3dConfiguration { + + conv::SplitKMode split_k_mode{}; + + /// Conv2d problem size + // contains strictly conv2d size (N,D,H,W,C,K,T,R,S,Z,P,Q,padding,stride,dilation,mode) + // also includes (split_k_slices, groups) + conv::Conv3dProblemSize problem_size{}; + + /// Layout object for activations tensor + layout::TensorNDHWC layout_activations{}; + + /// Layout object for filters tensor + layout::TensorNDHWC layout_filters{}; + + /// Layout object for source tensor + layout::TensorNDHWC layout_source{}; + + /// Layout object for output tensor + layout::TensorNDHWC layout_output{}; + + // + // Methods + // + + // Mapping functions (A,B,C -> activation,filter,output) + layout::TensorNDHWC layout_a(library::ConvKind const &conv_kind) const { + switch (conv_kind) { + case library::ConvKind::kFprop: return layout_activations; + case library::ConvKind::kDgrad: return layout_output; + case library::ConvKind::kWgrad: return layout_output; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + layout::TensorNDHWC layout_b(library::ConvKind const &conv_kind) const { + switch (conv_kind) { + case library::ConvKind::kFprop: return layout_filters; + case library::ConvKind::kDgrad: return layout_filters; + case library::ConvKind::kWgrad: return layout_activations; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + layout::TensorNDHWC 
layout_c(library::ConvKind const &conv_kind) const { + switch (conv_kind) { + case library::ConvKind::kFprop: return layout_output; + case library::ConvKind::kDgrad: return layout_activations; + case library::ConvKind::kWgrad: return layout_filters; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } +}; + +/// Arguments for CONV +struct ConvArguments { + + ///////////////////////////////////////////////////////// + /// ImplicitGemm matrices A, B, C, D + ///////////////////////////////////////////////////////// + /// pointer to implicit gemm matrix A + void const *A{nullptr}; + + /// pointer to implicit gemm matrix B + void const *B{nullptr}; + + /// pointer to reordered matrix B + void const *reordered_B{nullptr}; + + /// pointer to implicit gemm matrix C + void const *C{nullptr}; + + /// pointer to implicit gemm destination matrix D + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + /// Whether to use PDL when launching the kernel + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for Reduction operations +// +// OperationKind: Reduction +// +struct ReductionConfiguration { + + /// Reduction problem size + MatrixCoord problem_size{}; + + /// Number of partitions to reduce + int partitions{0}; + + /// Number of elements between each partition + int64_t partition_stride{0}; + + /// leading dimension of 'w'orkspace operand + int64_t ldw{0}; + + /// leading dimension of 's'ource operand + int64_t lds{0}; + + /// leading dimension of 'd'estination operand + int64_t ldd{0}; +}; + +/// Arguments for Reduction +struct ReductionArguments { + + /// Pointer to workspace matrix + void const 
*workspace{nullptr}; + + /// Pointer to source matrix + void const *source{nullptr}; + + /// Pointer to destination matrix + void *destination{nullptr}; + + /// pointer to reference matrix + void *reference{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + /// Whether to use PDL when launching the kernel + bool use_pdl{false}; +}; + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/manifest.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/manifest.h new file mode 100644 index 0000000000000000000000000000000000000000..c4fb0ee8ca32124450b1063cc3613078e600479d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/manifest.h @@ -0,0 +1,114 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Manifest of CUTLASS Library + + This is the root of the data structure containing CUTLASS objects +*/ + +#pragma once + +#include +#include +#include + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "library.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Forward declaration +class Manifest; + +// init and insert all cutlass gemm operations in manifest object (procedurally generated using generator.py) +void initialize_all(Manifest &manifest); + +// init and insert all reduction op in manifest object (manually instantiated in library/reduction) +void initialize_all_reduction_op(Manifest &manifest); + +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +/// List of operations +using OperationVector = std::vector>; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Manifest of CUTLASS Library +class Manifest { +private: + + /// Operation provider + Provider provider_; + + /// Global list of operations + OperationVector operations_; + +public: + Manifest (Provider provider = library::Provider::kCUTLASS) : provider_(provider) { } + + /// Top-level initialization + Status initialize(); + + /// Used for initialization + void reserve(size_t operation_count); + + /// Graceful shutdown + Status release(); + + /// Appends an operation and takes ownership + void append(Operation *operation_ptr) {\ + // This function is inline s.t. 
it is present in generated libraries + // without having to compile or link in manifest.cpp + operations_.emplace_back(operation_ptr); + } + + /// Returns an iterator to the first operation + OperationVector const &operations() const; + + /// Returns a const iterator + OperationVector::const_iterator begin() const; + + /// Returns a const iterator + OperationVector::const_iterator end() const; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/operation_table.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/operation_table.h new file mode 100644 index 0000000000000000000000000000000000000000..f36232c8dc833e2b24d681686f6662e79b7ecd0a --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/operation_table.h @@ -0,0 +1,905 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + \file + \brief Defines a data structure in which a set of functionally equivalent library::Operation + instances may be queried. 
+*/ + +#pragma once +#include +#include +#include +#include + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Data Structures for Gemm Functional Maps +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying Gemm functional behavior +struct GemmFunctionalKey { + + Provider provider; + GemmKind gemm_kind; + NumericTypeID element_compute; + NumericTypeID element_scalar; + NumericTypeID element_A; + LayoutTypeID layout_A; + ComplexTransform transform_A; + NumericTypeID element_B; + LayoutTypeID layout_B; + ComplexTransform transform_B; + NumericTypeID element_C; + LayoutTypeID layout_C; + NumericTypeID element_D; + LayoutTypeID layout_D; + + // + // Methods + // + + inline + GemmFunctionalKey( + Provider provider, + GemmKind gemm_kind = GemmKind::kGemm, + NumericTypeID element_compute = NumericTypeID::kF32, + NumericTypeID element_scalar = NumericTypeID::kF32, + NumericTypeID element_A = NumericTypeID::kF16, + LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, + ComplexTransform transform_A = ComplexTransform::kNone, + NumericTypeID element_B = NumericTypeID::kF16, + LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, + ComplexTransform transform_B = ComplexTransform::kNone, + NumericTypeID element_C = NumericTypeID::kF16, + LayoutTypeID layout_C = LayoutTypeID::kColumnMajor, + NumericTypeID element_D = NumericTypeID::kF16, + LayoutTypeID layout_D = LayoutTypeID::kColumnMajor + ): + provider(provider), + gemm_kind(gemm_kind), + element_compute(element_compute), + element_scalar(element_scalar), + element_A(element_A), + layout_A(layout_A), + transform_A(transform_A), + 
element_B(element_B), + layout_B(layout_B), + transform_B(transform_B), + element_C(element_C), + layout_C(layout_C), + element_D(element_D), + layout_D(layout_D) + { } + + inline + bool operator==(GemmFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (gemm_kind == rhs.gemm_kind) && + (element_compute == rhs.element_compute) && + (element_scalar == rhs.element_scalar) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (transform_A == rhs.transform_A) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (transform_B == rhs.transform_B) && + (element_C == rhs.element_C) && + (layout_C == rhs.layout_C) && + (element_D == rhs.element_D) && + (layout_D == rhs.layout_D); + } + + inline + bool operator!=(GemmFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k) { + + out << "{\n" + << " provider: " << to_string(k.provider) << "\n" + << " gemm_kind: " << to_string(k.gemm_kind) << "\n" + << " element_compute: " << to_string(k.element_compute) << "\n" + << " element_scalar: " << to_string(k.element_scalar) << "\n" + << " element_A: " << to_string(k.element_A) << "\n" + << " layout_A: " << to_string(k.layout_A) << "\n" + << " transform_A: " << to_string(k.transform_A) << "\n" + << " element_B: " << to_string(k.element_B) << "\n" + << " layout_B: " << to_string(k.layout_B) << "\n" + << " transform_B: " << to_string(k.transform_B) << "\n" + << " element_C: " << to_string(k.element_C) << "\n" + << " layout_C: " << to_string(k.layout_C) << "\n" + << " element_D: " << to_string(k.element_D) << "\n" + << " layout_D: " << to_string(k.layout_D) << "\n" + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function for 
GemmFunctionalKey +struct GemmFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(GemmFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.gemm_kind)), 2) ^ + rotl(hash(int(key.element_compute)), 3) ^ + rotl(hash(int(key.element_scalar)), 4) ^ + rotl(hash(int(key.element_A)), 5) ^ + rotl(hash(int(key.layout_A)), 6) ^ + rotl(hash(int(key.transform_A)), 7) ^ + rotl(hash(int(key.element_B)), 8) ^ + rotl(hash(int(key.layout_B)), 9) ^ + rotl(hash(int(key.transform_B)), 10) ^ + rotl(hash(int(key.element_C)), 11) ^ + rotl(hash(int(key.layout_C)), 12) ^ + rotl(hash(int(key.element_D)), 13) ^ + rotl(hash(int(key.layout_D)), 14); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Establishes a partial ordering to search for GEMM operators +struct GemmPreferenceKey { + + int compute_capability; + int alignment; + + // + // Methods + // + + GemmPreferenceKey(): compute_capability(), alignment() { } + + GemmPreferenceKey(int cc, int alignment): compute_capability(cc), alignment(alignment) { } + + bool operator<(GemmPreferenceKey const &rhs) const { + return (compute_capability < rhs.compute_capability) || + ((compute_capability == rhs.compute_capability) && (alignment < rhs.alignment)); + } + + bool operator==(GemmPreferenceKey const &rhs) const { + return compute_capability == rhs.compute_capability; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline +std::ostream& operator<< (std::ostream& out, const cutlass::library::GemmPreferenceKey& key) { + out << "{\n" + << "compute_capability : " << key.compute_capability << std::endl + << "alignment : " << key.alignment << std::endl + << "}"; + + return out; +} + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Maps minimum compute capability onto a vector of possible operations +using GemmOperationVectorMap = std::map< + GemmPreferenceKey, + std::vector +>; + +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using GemmOperationFunctionalMap = std::unordered_map< + GemmFunctionalKey, + GemmOperationVectorMap, + GemmFunctionalKeyHasher +>; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Data Structures for BlockScaled Gemm Functional Maps +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying Gemm functional behavior +struct BlockScaledGemmFunctionalKey { + + Provider provider; + GemmKind gemm_kind; + OperationKind kind; + NumericTypeID element_compute; + NumericTypeID element_scalar; + NumericTypeID element_A; + LayoutTypeID layout_A; + NumericTypeID element_SFA; + NumericTypeID element_B; + LayoutTypeID layout_B; + NumericTypeID element_SFB; + NumericTypeID element_C; + LayoutTypeID layout_C; + NumericTypeID element_D; + LayoutTypeID layout_D; + NumericTypeID element_SFD; + LayoutTypeID layout_SFD; + int SFVecSize; + int EpilogueSFVecSize; + // + // Methods + // + + inline + BlockScaledGemmFunctionalKey( + Provider provider, + GemmKind gemm_kind = GemmKind::kGemm, + OperationKind kind = OperationKind::kBlockScaledGemm, + NumericTypeID element_compute = NumericTypeID::kF32, + NumericTypeID element_scalar = NumericTypeID::kF32, + NumericTypeID element_A = NumericTypeID::kF16, + LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFA = NumericTypeID::kF16, + NumericTypeID element_B = NumericTypeID::kF16, + LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, + 
NumericTypeID element_SFB = NumericTypeID::kF16, + NumericTypeID element_C = NumericTypeID::kF16, + LayoutTypeID layout_C = LayoutTypeID::kColumnMajor, + NumericTypeID element_D = NumericTypeID::kF16, + LayoutTypeID layout_D = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFD = NumericTypeID::kF16, + LayoutTypeID layout_SFD = LayoutTypeID::kRowMajor, + int sf_vec_size = 32 + , int epilogue_sf_vec_size = 32 + ): + provider(provider), + gemm_kind(gemm_kind), + kind(kind), + element_compute(element_compute), + element_scalar(element_scalar), + element_A(element_A), + layout_A(layout_A), + element_SFA(element_SFA), + element_B(element_B), + layout_B(layout_B), + element_SFB(element_SFB), + element_C(element_C), + layout_C(layout_C), + element_D(element_D), + layout_D(layout_D), + element_SFD(element_SFD), + layout_SFD(layout_SFD), + SFVecSize(sf_vec_size) + , EpilogueSFVecSize(epilogue_sf_vec_size) + { } + + inline + bool operator==(BlockScaledGemmFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (gemm_kind == rhs.gemm_kind) && + (kind == rhs.kind) && + (element_compute == rhs.element_compute) && + (element_scalar == rhs.element_scalar) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (element_SFA == rhs.element_SFA) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (element_SFB == rhs.element_SFB) && + (element_C == rhs.element_C) && + (layout_C == rhs.layout_C) && + (element_D == rhs.element_D) && + (layout_D == rhs.layout_D) && + (element_SFD == rhs.element_SFD) && + (layout_SFD == rhs.layout_SFD) && + (SFVecSize == rhs.SFVecSize) + && (EpilogueSFVecSize == rhs.EpilogueSFVecSize) + ; + } + + inline + bool operator!=(BlockScaledGemmFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream & operator<<(std::ostream &out, 
cutlass::library::BlockScaledGemmFunctionalKey const &k) { + + out << "{\n" + << " provider: " << to_string(k.provider) << "\n" + << " gemm_kind: " << to_string(k.gemm_kind) << "\n" + << " kind: " << to_string(k.kind) << "\n" + << " element_compute: " << to_string(k.element_compute) << "\n" + << " element_scalar: " << to_string(k.element_scalar) << "\n" + << " element_A: " << to_string(k.element_A) << "\n" + << " layout_A: " << to_string(k.layout_A) << "\n" + << " element_SFA: " << to_string(k.element_SFA) << "\n" + << " element_B: " << to_string(k.element_B) << "\n" + << " layout_B: " << to_string(k.layout_B) << "\n" + << " element_SFB: " << to_string(k.element_SFB) << "\n" + << " element_C: " << to_string(k.element_C) << "\n" + << " layout_C: " << to_string(k.layout_C) << "\n" + << " element_D: " << to_string(k.element_D) << "\n" + << " layout_D: " << to_string(k.layout_D) << "\n" + << " element_SFD: " << to_string(k.element_SFD) << "\n" + << " layout_SFD: " << to_string(k.layout_SFD) << "\n" + << " SFVecSize: " << k.SFVecSize << "\n" + << "EpilogueSFVecSize: " << k.EpilogueSFVecSize << "\n" + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function for BlockScaledGemmFunctionalKeyHasher +struct BlockScaledGemmFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(BlockScaledGemmFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.gemm_kind)), 2) ^ + rotl(hash(int(key.kind)), 3) ^ + rotl(hash(int(key.element_compute)), 4) ^ + rotl(hash(int(key.element_scalar)), 5) ^ + rotl(hash(int(key.element_A)), 6) ^ + rotl(hash(int(key.layout_A)), 7) ^ + rotl(hash(int(key.element_SFA)), 8) ^ + rotl(hash(int(key.element_B)), 9) ^ + rotl(hash(int(key.layout_B)), 10) ^ + 
rotl(hash(int(key.element_SFB)), 11) ^ + rotl(hash(int(key.element_C)), 12) ^ + rotl(hash(int(key.layout_C)), 13) ^ + rotl(hash(int(key.element_D)), 14) ^ + rotl(hash(int(key.layout_D)), 15) ^ + rotl(hash(int(key.element_SFD)), 16) ^ + rotl(hash(int(key.layout_SFD)), 17) ^ + rotl(hash(int(key.SFVecSize)), 18) ^ + rotl(hash(int(key.EpilogueSFVecSize)), 19) + ; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using BlockScaledGemmOperationFunctionalMap = std::unordered_map< + BlockScaledGemmFunctionalKey, + GemmOperationVectorMap, + BlockScaledGemmFunctionalKeyHasher +>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Data Structures for Blockwise Gemm Functional Maps +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying Gemm functional behavior +struct BlockwiseGemmFunctionalKey { + + Provider provider; + GemmKind gemm_kind; + OperationKind kind; + NumericTypeID element_compute; + NumericTypeID element_scalar; + NumericTypeID element_A; + LayoutTypeID layout_A; + NumericTypeID element_SFA; + NumericTypeID element_B; + LayoutTypeID layout_B; + NumericTypeID element_SFB; + NumericTypeID element_C; + LayoutTypeID layout_C; + NumericTypeID element_D; + LayoutTypeID layout_D; + int SFMVecSize; + int SFNVecSize; + int SFKVecSize; + // + // Methods + // + + inline + BlockwiseGemmFunctionalKey( + Provider provider, + GemmKind gemm_kind = GemmKind::kGemm, + OperationKind kind = OperationKind::kBlockwiseGemm, + NumericTypeID element_compute = NumericTypeID::kF32, + NumericTypeID element_scalar = NumericTypeID::kF32, + NumericTypeID element_A = NumericTypeID::kF16, + LayoutTypeID layout_A = 
LayoutTypeID::kColumnMajor, + NumericTypeID element_SFA = NumericTypeID::kF16, + NumericTypeID element_B = NumericTypeID::kF16, + LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFB = NumericTypeID::kF16, + NumericTypeID element_C = NumericTypeID::kF16, + LayoutTypeID layout_C = LayoutTypeID::kColumnMajor, + NumericTypeID element_D = NumericTypeID::kF16, + LayoutTypeID layout_D = LayoutTypeID::kColumnMajor, + int sfm_vec_size = 32, + int sfn_vec_size = 32, + int sfk_vec_size = 32 + ): + provider(provider), + gemm_kind(gemm_kind), + kind(kind), + element_compute(element_compute), + element_scalar(element_scalar), + element_A(element_A), + layout_A(layout_A), + element_SFA(element_SFA), + element_B(element_B), + layout_B(layout_B), + element_SFB(element_SFB), + element_C(element_C), + layout_C(layout_C), + element_D(element_D), + layout_D(layout_D), + SFMVecSize(sfm_vec_size), + SFNVecSize(sfn_vec_size), + SFKVecSize(sfk_vec_size) + { } + + inline + bool operator==(BlockwiseGemmFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (gemm_kind == rhs.gemm_kind) && + (kind == rhs.kind) && + (element_compute == rhs.element_compute) && + (element_scalar == rhs.element_scalar) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (element_SFA == rhs.element_SFA) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (element_SFB == rhs.element_SFB) && + (element_C == rhs.element_C) && + (layout_C == rhs.layout_C) && + (element_D == rhs.element_D) && + (layout_D == rhs.layout_D) && + (SFMVecSize == rhs.SFMVecSize) && + (SFNVecSize == rhs.SFNVecSize) && + (SFKVecSize == rhs.SFKVecSize); + } + + inline + bool operator!=(BlockwiseGemmFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream & operator<<(std::ostream &out, 
cutlass::library::BlockwiseGemmFunctionalKey const &k) { + + out << "{\n" + << " provider: " << to_string(k.provider) << "\n" + << " gemm_kind: " << to_string(k.gemm_kind) << "\n" + << " kind: " << to_string(k.kind) << "\n" + << " element_compute: " << to_string(k.element_compute) << "\n" + << " element_scalar: " << to_string(k.element_scalar) << "\n" + << " element_A: " << to_string(k.element_A) << "\n" + << " layout_A: " << to_string(k.layout_A) << "\n" + << " element_SFA: " << to_string(k.element_SFA) << "\n" + << " element_B: " << to_string(k.element_B) << "\n" + << " layout_B: " << to_string(k.layout_B) << "\n" + << " element_SFB: " << to_string(k.element_SFB) << "\n" + << " element_C: " << to_string(k.element_C) << "\n" + << " layout_C: " << to_string(k.layout_C) << "\n" + << " element_D: " << to_string(k.element_D) << "\n" + << " layout_D: " << to_string(k.layout_D) << "\n" + << " SFMVecSize: " << k.SFMVecSize << "\n" + << " SFNVecSize: " << k.SFNVecSize << "\n" + << " SFKVecSize: " << k.SFKVecSize << "\n" + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function for BlockwiseGemmFunctionalKeyHasher +struct BlockwiseGemmFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(BlockwiseGemmFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.gemm_kind)), 2) ^ + rotl(hash(int(key.kind)), 3) ^ + rotl(hash(int(key.element_compute)), 4) ^ + rotl(hash(int(key.element_scalar)), 5) ^ + rotl(hash(int(key.element_A)), 6) ^ + rotl(hash(int(key.layout_A)), 7) ^ + rotl(hash(int(key.element_SFA)), 8) ^ + rotl(hash(int(key.element_B)), 9) ^ + rotl(hash(int(key.layout_B)), 10) ^ + rotl(hash(int(key.element_SFB)), 11) ^ + rotl(hash(int(key.element_C)), 12) ^ + 
rotl(hash(int(key.layout_C)), 13) ^ + rotl(hash(int(key.element_D)), 14) ^ + rotl(hash(int(key.layout_D)), 15) ^ + rotl(hash(int(key.SFMVecSize)), 16) ^ + rotl(hash(int(key.SFNVecSize)), 17) ^ + rotl(hash(int(key.SFKVecSize)), 18) + ; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using BlockwiseGemmOperationFunctionalMap = std::unordered_map< + BlockwiseGemmFunctionalKey, + GemmOperationVectorMap, + BlockwiseGemmFunctionalKeyHasher +>; + + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Data Structures for Conv Functional Maps +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying conv2d functional behavior +struct ConvFunctionalKey { + library::Provider provider; + library::ConvKind conv_kind; + library::NumericTypeID element_A; + library::LayoutTypeID layout_A; + library::NumericTypeID element_B; + library::LayoutTypeID layout_B; + library::NumericTypeID element_C; + library::LayoutTypeID layout_C; + library::NumericTypeID element_accumulator; + library::NumericTypeID element_compute; + + + // + // Methods + // + + inline + ConvFunctionalKey( + library::Provider provider = library::Provider::kInvalid, + library::ConvKind conv_kind = library::ConvKind::kFprop, + library::NumericTypeID element_A = library::NumericTypeID::kF16, + library::LayoutTypeID layout_A = library::LayoutTypeID::kTensorNHWC, + library::NumericTypeID element_B = library::NumericTypeID::kF16, + library::LayoutTypeID layout_B = library::LayoutTypeID::kTensorNHWC, + library::NumericTypeID element_C = library::NumericTypeID::kF16, + library::LayoutTypeID layout_C = library::LayoutTypeID::kTensorNHWC, + library::NumericTypeID element_accumulator = library::NumericTypeID::kF32, + library::NumericTypeID 
element_compute = library::NumericTypeID::kF32 + ): + provider(provider), + conv_kind(conv_kind), + element_A(element_A), + layout_A(layout_A), + element_B(element_B), + layout_B(layout_B), + element_C(element_C), + layout_C(layout_C), + element_accumulator(element_accumulator), + element_compute(element_compute) + { } + + inline + bool operator==(ConvFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (conv_kind == rhs.conv_kind) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (element_C == rhs.element_C) && + (layout_C == rhs.layout_C) && + (element_accumulator == rhs.element_accumulator) && + (element_compute == rhs.element_compute); + } + + inline + bool operator!=(ConvFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream& operator<< (std::ostream& out, const cutlass::library::ConvFunctionalKey& key) { + out << "{\n" + << "provider: " << to_string(key.provider) << std::endl + << "conv_kind: " << to_string(key.conv_kind) << std::endl + << "element_A: " << to_string(key.element_A) << std::endl + << "layout_A: " << to_string(key.layout_A) << std::endl + << "element_B: " << to_string(key.element_B) << std::endl + << "layout_B: " << to_string(key.layout_B) << std::endl + << "element_C: " << to_string(key.element_C) << std::endl + << "layout_C: " << to_string(key.layout_C) << std::endl + << "element_accumulator: " << to_string(key.element_accumulator) << std::endl + << "element_compute: " << to_string(key.element_compute) << std::endl + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +struct ConvFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> 
(sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(ConvFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.conv_kind)), 2) ^ + rotl(hash(int(key.element_A)), 3) ^ + rotl(hash(int(key.layout_A)), 4) ^ + rotl(hash(int(key.element_B)), 5) ^ + rotl(hash(int(key.layout_B)), 6) ^ + rotl(hash(int(key.element_C)), 7) ^ + rotl(hash(int(key.layout_C)), 8) ^ + rotl(hash(int(key.element_accumulator)), 9) ^ + rotl(hash(int(key.element_compute)), 10); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Establishes a partial ordering to search for Conv2d operators +struct ConvPreferenceKey { + + int compute_capability; + IteratorAlgorithmID iterator_algorithm; + + + // + // Methods + // + + ConvPreferenceKey(): compute_capability(), iterator_algorithm() { } + + ConvPreferenceKey(int cc, IteratorAlgorithmID iterator_algorithm): + compute_capability(cc), iterator_algorithm(iterator_algorithm) { } + + bool operator<(ConvPreferenceKey const &rhs) const { + return (compute_capability < rhs.compute_capability) || + ((compute_capability == rhs.compute_capability) && (iterator_algorithm < rhs.iterator_algorithm)); + } + + bool operator==(ConvPreferenceKey const &rhs) const { + return (compute_capability == rhs.compute_capability) && + (iterator_algorithm == rhs.iterator_algorithm); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Maps minimum compute capability onto a vector of possible operations +using ConvOperationVectorMap = std::map< + ConvPreferenceKey, + std::vector +>; + +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using ConvOperationFunctionalMap = std::unordered_map< + ConvFunctionalKey, + ConvOperationVectorMap, + ConvFunctionalKeyHasher +>; 
+///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Tuple uniquely identifying conv2d functional behavior +struct ReductionFunctionalKey { + library::Provider provider; + library::NumericTypeID element_workspace; + library::NumericTypeID element_accumulator; + library::NumericTypeID element_output; + library::NumericTypeID element_compute; + library::MathOperationID reduce_math_op; + library::EpilogueKind epilogue_math_op; + + + // + // Methods + // + + inline + ReductionFunctionalKey( + library::Provider provider = library::Provider::kInvalid, + library::NumericTypeID element_workspace = library::NumericTypeID::kF16, + library::NumericTypeID element_accumulator = library::NumericTypeID::kF32, + library::NumericTypeID element_output = library::NumericTypeID::kF16, + library::NumericTypeID element_compute = library::NumericTypeID::kF32, + library::MathOperationID reduce_math_op = library::MathOperationID::kAdd, + library::EpilogueKind epilogue_math_op = library::EpilogueKind::kLinearCombination + ): + provider(provider), + element_workspace(element_workspace), + element_accumulator(element_accumulator), + element_output(element_output), + element_compute(element_compute), + reduce_math_op(reduce_math_op), + epilogue_math_op(epilogue_math_op) + { } + + inline + bool operator==(ReductionFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (element_workspace == rhs.element_workspace) && + (element_accumulator == rhs.element_accumulator) && + (element_output == rhs.element_output) && + (element_compute == rhs.element_compute) && + (reduce_math_op == rhs.reduce_math_op) && + (epilogue_math_op == rhs.epilogue_math_op); + } + + inline + bool operator!=(ReductionFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + + +struct ReductionFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> 
(sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(ReductionFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.element_workspace)), 2) ^ + rotl(hash(int(key.element_accumulator)), 3) ^ + rotl(hash(int(key.element_output)), 4) ^ + rotl(hash(int(key.element_compute)), 5) ^ + rotl(hash(int(key.reduce_math_op)), 6) ^ + rotl(hash(int(key.epilogue_math_op)), 7); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline +std::ostream& operator<< (std::ostream& out, const ReductionFunctionalKey& key) { + out << "{\n" + << "provider: " << library::to_string(key.provider) << std::endl + << "element_workspace : " << library::to_string(key.element_workspace) << std::endl + << "element_accumulator : " << library::to_string(key.element_accumulator) << std::endl + << "element_output : " << library::to_string(key.element_output) << std::endl + << "element_compute : " << library::to_string(key.element_compute) << std::endl + << "}"; + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// ReductionOperationFunctionalMap has NO preference key and a single instance per functional key +// i.e. 
only one tile size configuration per functional key +using ReductionOperationFunctionalMap = std::unordered_map< + ReductionFunctionalKey, + library::Operation const *, + ReductionFunctionalKeyHasher +>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Table of cutlass::library::Operation instances +class OperationTable { +public: + + /// Map of all operations of type kGemm + // provider (kCUTLASS) + GemmOperationFunctionalMap gemm_operations; + + // provider (kCUTLASS, kReferenceHost, kReferenceDevice) + BlockScaledGemmOperationFunctionalMap block_scaled_gemm_operations; + + // provider (kCUTLASS, kReferenceHost, kReferenceDevice) + BlockwiseGemmOperationFunctionalMap blockwise_gemm_operations; + + /// Map of all operations of type kConv2d + // provider (kCUTLASS, kReferenceHost, kReferenceDevice) + ConvOperationFunctionalMap conv2d_operations; + + /// Map of all operations of type kConv3d + // provider (kCUTLASS, kReferenceHost, kReferenceDevice) + ConvOperationFunctionalMap conv3d_operations; + + /// Map of all operations of type kConv2d + // provider (kCUTLASS) + ReductionOperationFunctionalMap reduction_operations; + +public: + + void append(Manifest const &manifest); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k); diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/singleton.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/singleton.h new file mode 100644 index 0000000000000000000000000000000000000000..9a757433f38fbf10d9a352e07c7f3084a99e4098 --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/singleton.h @@ -0,0 +1,68 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/operation_table.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Singleton instance stores a Manifest and Operation table +class Singleton { +public: + + /// Manifest object + Manifest manifest; + + /// Operation table referencing the Manifest + OperationTable operation_table; + +public: + + Singleton(); + + static Singleton const &get(); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/types.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/types.h new file mode 100644 index 0000000000000000000000000000000000000000..9f8c4ff13ba543b4ec63997ba55e9278bfb357a6 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/types.h @@ -0,0 +1,295 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + #pragma once + + ///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Layout type identifier +enum class LayoutTypeID { + kUnknown, + kColumnMajor, + kRowMajor, + kBlockScalingTensor, + kColumnMajorInterleavedK2, + kRowMajorInterleavedK2, + kColumnMajorInterleavedK4, + kRowMajorInterleavedK4, + kColumnMajorInterleavedK16, + kRowMajorInterleavedK16, + kColumnMajorInterleavedK32, + kRowMajorInterleavedK32, + kColumnMajorInterleavedK64, + kRowMajorInterleavedK64, + kTensorNCHW, + kTensorNCDHW, + kTensorNHWC, + kTensorNDHWC, + kTensorNC32HW32, + kTensorC32RSK32, + kTensorNC64HW64, + kTensorC64RSK64, + kInvalid +}; + +/// Numeric data type +enum class NumericTypeID { + kUnknown, + kVoid, + kB1, + kU2, + kU4, + kU8, + kU16, + kU32, + kU64, + kS2, + kS4, + kS8, + kS16, + kS32, + kS64, + kFE4M3, + kFE5M2, + + kFE2M3, + kFE3M2, + kFE2M1, + kFUE8M0, + kFUE4M3, + kF8, + kF6, + kF4, + + kF16, + kBF16, + kTF32, + kF32, + kF64, + kCF16, + kCBF16, + kCF32, + kCTF32, + kCF64, + kCS2, + kCS4, + kCS8, + kCS16, + kCS32, + kCS64, + kCU2, + kCU4, + kCU8, + kCU16, + kCU32, + kCU64, + kInvalid +}; + +/// Enumerated type describing a transformation on a complex value. 
+enum class ComplexTransform { + kNone, + kConjugate, + kInvalid +}; + +/// Providers +enum class Provider { + kNone, + kCUTLASS, + kReferenceHost, + kReferenceDevice, + kCUBLAS, + kCUDNN, + kInvalid +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Enumeration indicating the kind of operation +enum class OperationKind { + kGemm, + kBlockScaledGemm, + kBlockwiseGemm, + kRankK, + kRank2K, + kTrmm, + kSymm, + kConv2d, + kConv3d, + kEqGemm, + kSparseGemm, + kReduction, + kGroupedGemm, + kInvalid +}; + +/// Enumeration indicating whether scalars are in host or device memory +enum class ScalarPointerMode { + kHost, + kDevice, + kInvalid +}; + +/// Describes how reductions are performed across threadblocks +enum class SplitKMode { + kNone, + kSerial, + kParallel, + kParallelSerial, + kInvalid +}; + +/// Indicates the classificaition of the math instruction +enum class OpcodeClassID { + kSimt, + kTensorOp, + kWmmaTensorOp, + kSparseTensorOp, + kBlockScaledOp, + kInvalid +}; + +enum class MathOperationID { + kAdd, + kMultiplyAdd, + kMultiplyAddSaturate, + kMultiplyAddMixedInputUpcast, + kMultiplyAddFastBF16, + kMultiplyAddFastF16, + kMultiplyAddFastF32, + kMultiplyAddComplex, + kMultiplyAddComplexFastF32, + kMultiplyAddGaussianComplex, + kXorPopc, + kInvalid +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Enumeration indicating what kind of GEMM operation to perform +enum class GemmKind { + kGemm, + kBlockScaledGemm, + kSparse, + kUniversal, + kPlanarComplex, + kPlanarComplexArray, + kGrouped, + kInvalid +}; + +/// Enumeration indicating what kind of RankK update operation to perform +enum class RankKKind { + kUniversal, + kInvalid +}; + +/// Enumeration indicating what kind of TRMM operation to perform +enum class TrmmKind { + kUniversal, + kInvalid +}; + +/// Enumeration indicating what kind of SYMM/HEMM operation to perform +enum class SymmKind { 
+ kUniversal, + kInvalid +}; + +/// Enumeration indicating what kind of Conv2d operation to perform +enum class ConvKind { + kUnknown, + kFprop, + kDgrad, + kWgrad, + kInvalid +}; + +enum class ConvModeID { + kCrossCorrelation, + kConvolution, + kInvalid +}; + +// Iterator algorithm enum in order of general performance-efficiency +enum class IteratorAlgorithmID { + kNone, + kAnalytic, + kOptimized, + kFixedChannels, + kFewChannels, + kInvalid +}; + + +enum class EpilogueKind { + kUnknown, + kConversion, + kLinearCombination, + kLinearCombinationClamp, + kLinearCombinationPlanarComplex, + kLinearCombinationRelu, + kLinearCombinationSigmoid, + kInvalid +}; + + +enum class RuntimeDatatype { + kStatic, + kE4M3, + kE5M2, + kE3M2, + kE2M3, + kE2M1, + + kInvalid +}; + + +enum class RasterOrder { + kAlongN, + kAlongM, + kHeuristic, + kInvalid +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/util.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/util.h new file mode 100644 index 0000000000000000000000000000000000000000..f537421751c1f2af3b95a2e1951006af441b28e0 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/util.h @@ -0,0 +1,281 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief Utilities accompanying the CUTLASS library for interacting with Library types. 
+*/ + +#ifndef CUTLASS_LIBRARY_UTIL_H +#define CUTLASS_LIBRARY_UTIL_H + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Lexical cast from string +template T from_string(std::string const &); + +/// Converts a Provider enumerant to a string +char const *to_string(Provider provider, bool pretty = false); + +/// Parses a Provider enumerant from a string +template <> Provider from_string(std::string const &str); + +/// Converts a GemmKind enumerant to a string +char const *to_string(GemmKind type, bool pretty = false); + +/// Converts a RankKKind enumerant to a string +char const *to_string(RankKKind type, bool pretty = false); + +/// Converts a TrmmKind enumerant to a string +char const *to_string(TrmmKind type, bool pretty = false); + +/// Converts a SymmKind enumerant to a string +char const *to_string(SymmKind type, bool pretty = false); + +/// Converts a SideMode enumerant to a string +char const *to_string(SideMode type, bool pretty = false); + +/// Converts a FillMode enumerant to a string +char const *to_string(FillMode type, bool pretty = false); + +/// Converts a BlasMode enumerant to a string +char const *to_string(BlasMode type, bool pretty = false); + +/// Converts a DiagType enumerant to a string +char const *to_string(DiagType type, bool pretty = false); + +/// Converts a NumericType enumerant to a string +char const *to_string(OperationKind type, bool pretty = false); + +/// Parses a NumericType enumerant from a string +template <> OperationKind from_string(std::string const &str); + +/// Converts a NumericType enumerant to a string +char const *to_string(NumericTypeID type, bool pretty = false); + +/// Parses a NumericType enumerant from a string +template <> NumericTypeID 
from_string(std::string const &str); + +/// Returns the size of a data type in bits +int sizeof_bits(NumericTypeID type); + +/// Returns true if the numeric type is a complex data type or false if real-valued. +bool is_complex_type(NumericTypeID type); + +/// Returns the real-valued type underlying a type (only different from 'type' if complex) +NumericTypeID get_real_type(NumericTypeID type); + +/// Returns true if numeric type is integer +bool is_integer_type(NumericTypeID type); + +/// Returns true if numeric type is signed +bool is_signed_type(NumericTypeID type); + +/// Returns true if numeric type is a signed integer +bool is_signed_integer(NumericTypeID type); + +/// returns true if numeric type is an unsigned integer +bool is_unsigned_integer(NumericTypeID type); + +/// Returns true if numeric type is floating-point type +bool is_float_type(NumericTypeID type); + +/// To string method for cutlass::Status +char const *to_string(Status status, bool pretty = false); + +/// Converts a LayoutTypeID enumerant to a string +char const *to_string(LayoutTypeID layout, bool pretty = false); + +/// Parses a LayoutType enumerant from a string +template <> LayoutTypeID from_string(std::string const &str); + +/// Returns the rank of a layout's stride base on the LayoutTypeID +int get_layout_stride_rank(LayoutTypeID layout_id); + +/// Converts a OpcodeClassID enumerant to a string +char const *to_string(OpcodeClassID type, bool pretty = false); + +/// Converts a OpcodeClassID enumerant from a string +template <> +OpcodeClassID from_string(std::string const &str); + +/// Converts a ComplexTransform enumerant to a string +char const *to_string(ComplexTransform type, bool pretty = false); + +/// Converts a ComplexTransform enumerant from a string +template <> +ComplexTransform from_string(std::string const &str); + + +/// Converts a SplitKMode enumerant to a string +char const *to_string(SplitKMode split_k_mode, bool pretty = false); + +/// Converts a SplitKMode enumerant 
from a string +template <> +SplitKMode from_string(std::string const &str); + +/// Converts a ConvModeID enumerant to a string +char const *to_string(ConvModeID type, bool pretty = false); + +/// Converts a ConvModeID enumerant from a string +template <> +ConvModeID from_string(std::string const &str); + +/// Converts a IteratorAlgorithmID enumerant to a string +char const *to_string(IteratorAlgorithmID type, bool pretty = false); + +/// Converts a IteratorAlgorithmID enumerant from a string +template <> +IteratorAlgorithmID from_string(std::string const &str); + +/// Converts a ConvKind enumerant to a string +char const *to_string(ConvKind type, bool pretty = false); + +/// Converts a ConvKind enumerant from a string +template <> +ConvKind from_string(std::string const &str); + + +/// Converts a RuntimeDatatype enumerant to a string +char const *to_string(cutlass::library::RuntimeDatatype type, bool pretty = false); + +/// Convers a RuntimeDatatype enumerant from a string +template<> +cutlass::library::RuntimeDatatype from_string(std::string const &str); + + +/// Converts a RasterOrder enumerant to a string +char const *to_string(RasterOrder type, bool pretty = false); + +/// Convers a RasterOrder enumerant from a string +template<> +RasterOrder from_string(std::string const &str); + +/// Converts a bool to a string +char const *to_string(bool type, bool pretty = false); + +/// Convers a bool from a string +template<> +bool from_string(std::string const &str); + +/// Lexical cast from int64_t to string +std::string lexical_cast(int64_t int_value); + +/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. +bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string const &str); + +/// Lexical cast TO a string FROM a byte array. Returns true if cast is successful or false if invalid. +std::string lexical_cast(std::vector &bytes, NumericTypeID type); + +/// Casts from a signed int64 to the destination type. 
Returns true if successful. +bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t src); + +/// Casts from an unsigned int64 to the destination type. Returns true if successful. +bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t src); + +/// Casts from a real value represented as a double to the destination type. Returns true if successful. +bool cast_from_double(std::vector &bytes, NumericTypeID type, double src); + +NumericTypeID dynamic_datatype_to_id(RuntimeDatatype type); + +#define CUDA_CHECK(call) \ + do { \ + cudaError_t err = (call); \ + if (err != cudaSuccess) { \ + std::cerr << "CUDA Error: " << cudaGetErrorString(err) << " in " << __func__ << " at " \ + << __FILE__ << ":" << __LINE__ << std::endl; \ + return Status::kInvalid; \ + } \ + } while (0) + +// RAII CUDA buffer container +class CudaBuffer { +public: + CudaBuffer() : size_(0), d_ptr_(nullptr) {} + + explicit CudaBuffer(size_t size) : size_(size), d_ptr_(nullptr) { + cudaError_t err = cudaMalloc(&d_ptr_, size_); + if (err != cudaSuccess) { + throw std::runtime_error("cudaMalloc failed: " + std::string(cudaGetErrorString(err))); + } + } + + ~CudaBuffer() { + if (d_ptr_) { + cudaFree(d_ptr_); + } + } + + CudaBuffer(CudaBuffer const&) = delete; + CudaBuffer& operator=(CudaBuffer const&) = delete; + + CudaBuffer(CudaBuffer&& other) noexcept : size_(other.size_), d_ptr_(other.d_ptr_) { + other.d_ptr_ = nullptr; + other.size_ = 0; + } + + CudaBuffer& operator=(CudaBuffer&& other) noexcept { + if (this != &other) { + if (d_ptr_) { + cudaFree(d_ptr_); + } + d_ptr_ = other.d_ptr_; + size_ = other.size_; + other.d_ptr_ = nullptr; + other.size_ = 0; + } + return *this; + } + + void* data() const noexcept { return d_ptr_; } + size_t size() const noexcept { return size_; } + +private: + size_t size_; + void* d_ptr_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/block_scaled_gemm_operation_3x.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/block_scaled_gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c96b9a2212b42c191551ea70da3ac3baecbed487 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/block_scaled_gemm_operation_3x.hpp @@ -0,0 +1,450 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "gemm_operation_3x.hpp" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class BlockScaledGemmUniversal3xOperation : public GemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::CollectiveMainloop::ElementA; + using ElementSFA = typename Operator::CollectiveMainloop::ElementSF; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::CollectiveMainloop::ElementB; + using ElementSFB = typename Operator::CollectiveMainloop::ElementSF; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using 
ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using TiledMma = typename Operator::CollectiveMainloop::TiledMma; + constexpr static int SFVecSize = TiledMma::SFVecSize; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + using Sm1xxBlkScaledConfig = typename CollectiveMainloop::Sm1xxBlkScaledConfig; + + static constexpr bool epilogue_scalefactor_generation = not cute::is_same_v; + static constexpr int32_t SFD_VectorSize = epilogue_scalefactor_generation ? ThreadEpilogueOp::SFVecSize : SFVecSize; + using ElementSFD = cute::conditional_t; + using LayoutSFD = cute::conditional_t; + + + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + using RuntimeDataTypeA = typename Operator::CollectiveMainloop::RuntimeDataTypeA; + using RuntimeDataTypeB = typename Operator::CollectiveMainloop::RuntimeDataTypeB; + + +private: + BlockScaledGemmDescription description_; + +public: + + /// Constructor + BlockScaledGemmUniversal3xOperation(char const *name = "unknown_gemm"): + GemmOperation3xBase(name, GemmKind::kUniversal) { + description_.kind = OperationKind::kBlockScaledGemm; + description_.SFA.element = NumericTypeMap::kId; + description_.SFA.layout = LayoutTypeID::kRowMajor; + description_.SFA.alignment = 128; + description_.SFA.log_extent_range = 32; + description_.SFA.log_stride_range = 32; + + description_.SFB.element = NumericTypeMap::kId; + 
description_.SFB.layout = LayoutTypeID::kRowMajor; + description_.SFB.alignment = 128; + description_.SFB.log_extent_range = 32; + description_.SFB.log_stride_range = 32; + + description_.SFVecSize = SFVecSize; + + description_.SFD = make_TensorDescription(128); + description_.EpilogueSFVecSize = SFD_VectorSize; + + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.gemm_kind = GemmKind::kUniversal; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + description_.tile_description.cluster_shape = make_Coord( + Operator::ClusterShape::kM, + Operator::ClusterShape::kN, + Operator::ClusterShape::kK); + } + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::WarpCount::kM, + Operator::WarpCount::kN, + Operator::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentD); + description_.element_epilogue = NumericTypeMap::kId; + + 
description_.split_k_mode = SplitKMode::kNone; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + /// Returns the description of the GEMM operation + BlockScaledGemmDescription const& get_gemm_description() const { + return description_; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, BlockScaledGemmArguments const &arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, BlockScaledGemmArguments const &arguments) { + + if constexpr (epilogue_scalefactor_generation) { + fusion_args.block_scale_factor_ptr = static_cast(arguments.SFD); + fusion_args.norm_constant_ptr = static_cast(arguments.norm_constant); + } + + + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = 
static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + BlockScaledGemmArguments const *arguments) { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.problem_shape = cute::make_shape( + arguments->problem_size.m(), + arguments->problem_size.n(), + arguments->problem_size.k(), + arguments->batch_count); + + // update arguments + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + + using RuntimeDataTypeA = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeA; + using RuntimeDataTypeB = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeB; + + static_assert(cute::is_same_v, + "RuntimeDataTypeA/B should be identical, either MXF8F6F4Format or MXF4Format"); + using RuntimeDatatypeArg = RuntimeDataTypeA; + + auto mapping = [](RuntimeDatatype type) { + if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE3M2) { + return cute::UMMA::MXF8F6F4Format::E3M2; + } else if (type == RuntimeDatatype::kE2M3) { + return cute::UMMA::MXF8F6F4Format::E2M3; + } else if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF8F6F4Format::E2M1; + } else { + assert("Invalid input datatype."); + } + } + else if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF4Format::E2M1; + } else { + assert("Invalid input datatype."); + } + } + // BlockScaled kernels 
receive either MXF4Format or MXF8F6F4Format runtime datatype + CUTE_GCC_UNREACHABLE; + }; + + operator_args.mainloop.runtime_data_type_a = mapping(arguments->runtime_input_datatype_a); + operator_args.mainloop.runtime_data_type_b = mapping(arguments->runtime_input_datatype_b); + + } + else { + + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + } + operator_args.mainloop.ptr_SFA = static_cast(arguments->SFA); + operator_args.mainloop.ptr_SFB = static_cast(arguments->SFB); + operator_args.epilogue.ptr_C = static_cast(arguments->C); + operator_args.epilogue.ptr_D = static_cast(arguments->D); + + operator_args.mainloop.dA = cute::make_int_tuple_from( + arguments->lda, arguments->batch_stride_A); + operator_args.mainloop.dB = cute::make_int_tuple_from( + arguments->ldb, arguments->batch_stride_B); + operator_args.epilogue.dC = cute::make_int_tuple_from( + arguments->ldc, arguments->batch_stride_C); + operator_args.epilogue.dD = operator_args.epilogue.dC; + + operator_args.mainloop.layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(operator_args.problem_shape); + operator_args.mainloop.layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(operator_args.problem_shape); + + /* Query device SM count to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments->raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = 
arguments->split_k_slices; + } + + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + + return status; + } + +public: + + /// Returns success if the operation can proceed + Status can_implement( + void const *configuration_ptr, void const *arguments_ptr) const override { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + BlockScaledGemmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + // can_implement rules may need access to problem shape + args.problem_shape = cute::make_shape( + configuration->problem_size.m(), + configuration->problem_size.n(), + configuration->problem_size.k(), + configuration->batch_count); + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const *configuration) const override { + return sizeof(Operator); + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + Operator *op = new (host_workspace) Operator; + return Status::kSuccess; + } + + 
Status initialize_with_profiler_workspace( + void const *configuration, + void *host_workspace, + void *device_workspace, + uint8_t **profiler_workspaces, + int problem_count_from_profiler, + cudaStream_t stream = nullptr) { + return Status::kSuccess; + } + + /// Runs the kernel + Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments args; + Status status = update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(args, device_workspace, stream, nullptr, static_cast(arguments_ptr)->use_pdl); + return status; + } +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/blockwise_gemm_operation_3x.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/blockwise_gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..00347a993e29035e58401e69698267045b399f7d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/blockwise_gemm_operation_3x.hpp @@ -0,0 +1,429 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. 
+*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "gemm_operation_3x.hpp" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class BlockwiseGemmUniversal3xOperation : public GemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::CollectiveMainloop::ElementA; + using ElementSFA = typename Operator::ElementAccumulator; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::CollectiveMainloop::ElementB; + using ElementSFB = typename Operator::ElementAccumulator; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using TiledMma = typename Operator::CollectiveMainloop::TiledMma; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + 
static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + +private: + BlockwiseGemmDescription description_; + +public: + + /// Constructor + BlockwiseGemmUniversal3xOperation(char const *name = "unknown_gemm"): + GemmOperation3xBase(name, GemmKind::kUniversal) { + description_.kind = OperationKind::kBlockwiseGemm; + description_.SFA.element = NumericTypeMap::kId; + description_.SFA.layout = size<0,1>(typename CollectiveMainloop::LayoutSFA{}.stride()) == 1 ? + LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; + description_.SFA.alignment = CollectiveMainloop::AlignmentSFA; + description_.SFA.log_extent_range = 32; + description_.SFA.log_stride_range = 32; + + description_.SFB.element = NumericTypeMap::kId; + description_.SFB.layout = size<0,1>(typename CollectiveMainloop::LayoutSFB{}.stride()) == 1 ? + LayoutTypeID::kRowMajor : LayoutTypeID::kColumnMajor; + description_.SFB.alignment = CollectiveMainloop::AlignmentSFA; + description_.SFB.log_extent_range = 32; + description_.SFB.log_stride_range = 32; + + description_.SFMVecSize = Operator::CollectiveMainloop::ScaleGranularityM; + description_.SFNVecSize = Operator::CollectiveMainloop::ScaleGranularityN; + description_.SFKVecSize = Operator::CollectiveMainloop::ScaleGranularityK; + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.gemm_kind = GemmKind::kUniversal; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + description_.tile_description.cluster_shape = make_Coord( + Operator::ClusterShape::kM, + Operator::ClusterShape::kN, + Operator::ClusterShape::kK); + } + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::WarpCount::kM, + Operator::WarpCount::kN, + 
Operator::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentD); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + /// Returns the description of the GEMM operation + BlockwiseGemmDescription const& get_gemm_description() const { + return description_; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, BlockwiseGemmArguments const &arguments) { 
+ // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, BlockwiseGemmArguments const &arguments) { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + BlockwiseGemmArguments const *arguments) { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.problem_shape = cute::make_shape( + arguments->problem_size.m(), + arguments->problem_size.n(), + arguments->problem_size.k(), + arguments->batch_count); + + // update arguments + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + + std::unordered_map mapping = { + {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3}, + {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2}, + {RuntimeDatatype::kE3M2, 
cute::UMMA::MXF8F6F4Format::E3M2}, + {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1} + }; + + auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a); + auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b); + + if (iter_runtime_a != mapping.end()) { + operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second; + } else { + assert("invalid runtime argument for datatype A!"); + } + + if (iter_runtime_b != mapping.end()) { + operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second; + } else { + assert("invalid runtime argument for datatype B!"); + } + + } + else { + + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + } + operator_args.mainloop.ptr_SFA = static_cast(arguments->SFA); + operator_args.mainloop.ptr_SFB = static_cast(arguments->SFB); + operator_args.epilogue.ptr_C = static_cast(arguments->C); + operator_args.epilogue.ptr_D = static_cast(arguments->D); + + operator_args.mainloop.dA = cute::make_int_tuple_from( + arguments->lda, arguments->batch_stride_A); + operator_args.mainloop.dB = cute::make_int_tuple_from( + arguments->ldb, arguments->batch_stride_B); + operator_args.epilogue.dC = cute::make_int_tuple_from( + arguments->ldc, arguments->batch_stride_C); + operator_args.epilogue.dD = operator_args.epilogue.dC; + + operator_args.mainloop.layout_SFA = Operator::CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFA(operator_args.problem_shape); + operator_args.mainloop.layout_SFB = Operator::CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFB(operator_args.problem_shape); + + /* Query device SM count to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); 
+ switch (arguments->raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = arguments->split_k_slices; + } + + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + + return status; + } + +public: + + /// Returns success if the operation can proceed + Status can_implement( + void const *configuration_ptr, void const *arguments_ptr) const override { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + BlockwiseGemmArguments const *arguments = + static_cast(arguments_ptr); + + if (arguments->sf_m_vec_size != description_.SFMVecSize && arguments->sf_m_vec_size != 0) { + return Status::kErrorInvalidProblem; + } + if (arguments->sf_n_vec_size != description_.SFNVecSize && arguments->sf_n_vec_size != 0) { + return Status::kErrorInvalidProblem; + } + if (arguments->sf_k_vec_size != description_.SFKVecSize && arguments->sf_k_vec_size != 0) { + return Status::kErrorInvalidProblem; + } + + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + // can_implement rules may need access to problem shape + args.problem_shape = cute::make_shape( + configuration->problem_size.m(), + configuration->problem_size.n(), + configuration->problem_size.k(), + configuration->batch_count); + + return Operator::can_implement(args); + } + + /// 
Gets the host-side workspace + uint64_t get_host_workspace_size(void const *configuration) const override { + return sizeof(Operator); + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + Operator *op = new (host_workspace) Operator; + return Status::kSuccess; + } + + Status initialize_with_profiler_workspace( + void const *configuration, + void *host_workspace, + void *device_workspace, + uint8_t **profiler_workspaces, + int problem_count_from_profiler, + cudaStream_t stream = nullptr) { + return Status::kSuccess; + } + + /// Runs the kernel + Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments args; + Status status = update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(args, device_workspace, stream, nullptr, static_cast(arguments_ptr)->use_pdl); + return status; + } +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/conv2d_operation.h 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/conv2d_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..3b1a1584db92c4379e04c84a2658f79313b3eaad --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/conv2d_operation.h @@ -0,0 +1,650 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all CONV operation kinds in CUTLASS Library. +*/ + +#pragma once +#include +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/kernel/default_conv2d_group_fprop.h" +#include "cutlass/conv/kernel/default_depthwise_fprop.h" +#include "cutlass/conv/kernel/default_conv2d_dgrad.h" +#include "cutlass/conv/kernel/default_conv2d_wgrad.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/conv/device/direct_convolution.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/util/host_tensor.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/core_io.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Conv2dOperationBase : public Operation { +public: + + using Operator = Operator_; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = 
typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = Operator::kIteratorAlgorithm; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + ConvDescription description_; + +public: + + /// Constructor + Conv2dOperationBase(char const *name = "unknown_conv2d") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kConv2d; + description_.conv_dim = Operator::kConvDim; + + description_.iterator_algorithm = IteratorAlgorithmMap::kId; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::UnderlyingKernel::WarpCount::kM, + Operator::UnderlyingKernel::WarpCount::kN, + Operator::UnderlyingKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(); + description_.B = 
make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.element_epilogue = NumericTypeMap::kId; + + // TODO: Add split k mode Serial and parallel to convolutions + // description_.split_k_mode = Operator::kSplitK ? SplitKMode::kSerial : SplitKMode::kNone; + + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Conv2d library operation class for cutlass profiler +// +/////////////////////////////////////////////////////////////////////////////////////////////////// +template +class Conv2dOperation : public Conv2dOperationBase { +public: + + using Operator = Operator_; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +public: + /// Constructor + Conv2dOperation(char const *name = "unknown_conv2d_fprop") : Conv2dOperationBase(name) { + this->description_.conv_kind = ConvKindMap::kId; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + Conv2dConfiguration const *configuration) { + + + operator_args.problem_size = configuration->problem_size; + + operator_args.ref_A = + { + nullptr, + LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) + 
}; + + operator_args.ref_B = + { + nullptr, + LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_C = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_D = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.split_k_mode = configuration->split_k_mode; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + ConvArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); + operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); + operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); + operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + Conv2dConfiguration const *configuration = + static_cast(configuration_ptr); + + ConvArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = 
construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + //std::cout << "initialize library::Conv2dOperation" << std::endl; + //print_operator_args(args); + return op->initialize(args, device_workspace, stream); + + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + //std::cout << "run library::Conv2dOperation" << std::endl; + //print_operator_args(args); + return op->run(stream); + } + + /// Call 
print_operator_args from the Conv2dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Conv2dOperation::OperatorArguments" << std::endl + << " problem_size:" << std::endl + << operator_args.problem_size << std::endl + << " split_k_mode: " + << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial" : "parallel") << std::endl + << " epilogue (alpha, beta): " + << operator_args.output_op.alpha << ", " + << operator_args.output_op.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ref_A.data() << ", {" + << operator_args.ref_A.stride(0) << ", " + << operator_args.ref_A.stride(1) << ", " + << operator_args.ref_A.stride(2) << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ref_B.data() << ", {" + << operator_args.ref_B.stride(0) << ", " + << operator_args.ref_B.stride(1) << ", " + << operator_args.ref_B.stride(2) << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ref_C.data() << ", {" + << operator_args.ref_C.stride(0) << ", " + << operator_args.ref_C.stride(1) << ", " + << operator_args.ref_C.stride(2) << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ref_D.data() << ", {" + << operator_args.ref_D.stride(0) << ", " + << operator_args.ref_D.stride(1) << ", " + << operator_args.ref_D.stride(2) << "}" << std::endl; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// DirectConv2d library operation class for cutlass profiler +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class DirectConv2dOperation : public Conv2dOperation { +public: + + using Operator = Operator_; + using Base = Conv2dOperation; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = 
typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +public: + /// Constructor + DirectConv2dOperation(char const *name = "unknown_direct)conv2d_fprop") : Conv2dOperation(name) { + this->description_.conv_kind = ConvKindMap::kId; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + Conv2dConfiguration const *configuration) { + + + operator_args.problem_size = configuration->problem_size; + + operator_args.ref_A = + { + nullptr, + LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_B = + { + nullptr, + LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_reordered_B = + { + nullptr, + LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_C = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_D = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.split_k_mode = configuration->split_k_mode; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + ConvArguments const *arguments) { + + if 
(arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); + operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); + operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); + operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); + operator_args.ref_reordered_B.reset(static_cast(const_cast(arguments->reordered_B))); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + Conv2dConfiguration const *configuration = + static_cast(configuration_ptr); + + ConvArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + 
static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + //std::cout << "initialize library::Conv2dOperation" << std::endl; + //print_operator_args(args); + return op->initialize(args, device_workspace, stream); + + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + //std::cout << "run library::Conv2dOperation" << std::endl; + //print_operator_args(args); + return op->run(stream); + } + + /// Call print_operator_args from the Conv2dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Conv2dOperation::OperatorArguments" << std::endl + << " problem_size:" << std::endl + << operator_args.problem_size << std::endl + << " split_k_mode: " + << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? 
"serial" : "parallel") << std::endl + << " epilogue (alpha, beta): " + << operator_args.output_op.alpha << ", " + << operator_args.output_op.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ref_A.data() << ", {" + << operator_args.ref_A.stride(0) << ", " + << operator_args.ref_A.stride(1) << ", " + << operator_args.ref_A.stride(2) << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ref_B.data() << ", {" + << operator_args.ref_B.stride(0) << ", " + << operator_args.ref_B.stride(1) << ", " + << operator_args.ref_B.stride(2) << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ref_C.data() << ", {" + << operator_args.ref_C.stride(0) << ", " + << operator_args.ref_C.stride(1) << ", " + << operator_args.ref_C.stride(2) << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ref_D.data() << ", {" + << operator_args.ref_D.stride(0) << ", " + << operator_args.ref_D.stride(1) << ", " + << operator_args.ref_D.stride(2) << "}" << std::endl; + } +}; + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/conv3d_operation.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/conv3d_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..fe402c4494c27a882bf42f867a708e954ee87dc0 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/conv3d_operation.h @@ -0,0 +1,389 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all CONV operation kinds in CUTLASS Library. 
+*/ + +#pragma once +#include +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv3d_fprop.h" +#include "cutlass/conv/kernel/default_conv3d_dgrad.h" +#include "cutlass/conv/kernel/default_conv3d_wgrad.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/util/host_tensor.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/core_io.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Conv3dOperationBase : public Operation { +public: + + using Operator = Operator_; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = Operator::kIteratorAlgorithm; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + ConvDescription description_; + +public: + + /// Constructor + Conv3dOperationBase(char const *name = "unknown_conv3d") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kConv3d; + description_.conv_dim = Operator::kConvDim; + + description_.iterator_algorithm = IteratorAlgorithmMap::kId; + + 
description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::UnderlyingKernel::WarpCount::kM, + Operator::UnderlyingKernel::WarpCount::kN, + Operator::UnderlyingKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.element_epilogue = NumericTypeMap::kId; + + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Conv2d library operation class for cutlass profiler +// +/////////////////////////////////////////////////////////////////////////////////////////////////// +template +class Conv3dOperation : public Conv3dOperationBase { +public: + + using Operator = Operator_; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = 
typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +public: + /// Constructor + Conv3dOperation(char const *name = "unknown_conv3d_fprop") : Conv3dOperationBase(name) { + this->description_.conv_kind = ConvKindMap::kId; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + Conv3dConfiguration const *configuration) { + + + operator_args.problem_size = configuration->problem_size; + + operator_args.ref_A = + { + nullptr, + LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_B = + { + nullptr, + LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_C = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_D = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.split_k_mode = configuration->split_k_mode; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + ConvArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + 
static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); + operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); + operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); + operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + Conv3dConfiguration const *configuration = + static_cast(configuration_ptr); + + ConvArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + 
return status; + } + + Operator *op = new (host_workspace) Operator; + //std::cout << "initialize library::Conv3dOperation" << std::endl; + //print_operator_args(args); + return op->initialize(args, device_workspace, stream); + + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + //std::cout << "run library::Conv3dOperation" << std::endl; + //print_operator_args(args); + return op->run(stream); + } + + /// Call print_operator_args from the Conv3dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Conv3dOperation::OperatorArguments" << std::endl + << " problem_size: " + << operator_args.problem_size << std::endl + << " split_k_mode: " + << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? 
"serial" : "parallel") << std::endl + << " epilogue (alpha, beta): " + << operator_args.output_op.alpha << ", " + << operator_args.output_op.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ref_A.data() << ", {" + << operator_args.ref_A.stride(0) << ", " + << operator_args.ref_A.stride(1) << ", " + << operator_args.ref_A.stride(2) << ", " + << operator_args.ref_A.stride(3) << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ref_B.data() << ", {" + << operator_args.ref_B.stride(0) << ", " + << operator_args.ref_B.stride(1) << ", " + << operator_args.ref_B.stride(2) << ", " + << operator_args.ref_B.stride(3) << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ref_C.data() << ", {" + << operator_args.ref_C.stride(0) << ", " + << operator_args.ref_C.stride(1) << ", " + << operator_args.ref_C.stride(2) << ", " + << operator_args.ref_C.stride(3) << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ref_D.data() << ", {" + << operator_args.ref_D.stride(0) << ", " + << operator_args.ref_D.stride(1) << ", " + << operator_args.ref_D.stride(2) << ", " + << operator_args.ref_D.stride(3) << "}" << std::endl; + } +}; + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/conv_operation_3x.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/conv_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..86c1513e9c934c22e281cf37e1c5e7783e23d305 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/conv_operation_3x.hpp @@ -0,0 +1,980 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all CONV operation kinds in CUTLASS Library. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/trace.h" +#include +#include +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) +#include +#endif + +namespace cutlass::library { + +namespace detail { + +template +constexpr cute::array +vector_to_array_strides_helper(const std::vector& v, + std::index_sequence) +{ + return {v[(sizeof...(Indices) - 1u) - Indices]..., ValueType(1)}; +} + +template +cute::array +vector_to_array_strides(const std::vector& v, std::integral_constant) +{ + static_assert(Size != 0); + CUTLASS_ASSERT(v.size() + 1u == Size); + return vector_to_array_strides_helper(v, std::make_index_sequence{}); +} + +template +constexpr cute::array +coord_to_array_strides_helper( + const ::cutlass::Coord coord, + std::index_sequence) +{ + return {int64_t(coord[(sizeof...(Indices) - 1u) - Indices])..., int64_t(1)}; +} + +template +cute::array +coord_to_array_strides(const ::cutlass::Coord& coord) +{ + static_assert(Rank >= 0); + return coord_to_array_strides_helper(coord, std::make_index_sequence{}); +} + +} // namespace detail + +// Tells the profiler about CUTLASS 3's 2-D and 3-D convolutions. +// For CUTLASS 2's 2-D convolutions, see Conv2dOperation. +// For CUTLASS 2's 3-D convolutions, see Conv3dOperation. 
+template +class ConvOperation3x : public Operation { +public: + using Operator = Operator_; + + static_assert(Operator::NumSpatialDimensions == 2 || + Operator::NumSpatialDimensions == 3, + "The profiler currently only supports convolutions with 2 or 3 spatial dimensions."); + using LayoutA = cute::conditional_t + >; + using LayoutB = LayoutA; + using LayoutC = LayoutA; + + using ElementA = typename Operator::ElementA; + using ElementB = typename Operator::ElementB; + using ElementC = typename Operator::ElementC; + using ElementD = typename Operator::ElementD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + ConvOperation3x(const char* name = "unknown_cutlass_3_conv") { + // Initialize OperationDescription (the base class) + description_.name = name; + description_.provider = Provider::kCUTLASS; + + if constexpr (Operator::NumSpatialDimensions == 2) { + description_.kind = OperationKind::kConv2d; + } + else if constexpr (Operator::NumSpatialDimensions == 3) { + description_.kind = OperationKind::kConv3d; + } + else { + static_assert(::cutlass::detail::dependent_false, + "This class currently only supports 2-D and 3-D convolutions."); + } + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::WarpCount::kM, + Operator::WarpCount::kN, + Operator::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = 
+ NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationID::kMultiplyAdd; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + // Initialize ConvDescription (the subclass) + + // kConvDim does not exist in Operator for CUTLASS 3 convolutions. + // For CUTLASS 2 convolutions, it is the number of spatial dimensions. + description_.conv_dim = Operator::NumSpatialDimensions; + description_.conv_kind = ConvKindMap::kId; + + description_.iterator_algorithm = {}; + + description_.A = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.element_epilogue = NumericTypeMap::kId; + } + + ~ConvOperation3x() override = default; + + OperationDescription const& description() const override { + return static_cast(description_); + } + +private: + Status update_operator_arguments_from_configuration_2d_or_3d( + typename Operator::Arguments& out_args, + void const* configuration) const { + Status status = Status::kInvalid; + + CUTLASS_ASSERT(configuration != nullptr); + + if constexpr (Operator::NumSpatialDimensions == 2) { + CUTLASS_ASSERT(description_.kind == OperationKind::kConv2d); + // tools/library/include/cutlass/library/library.h + // defines Conv2dConfiguration. + // tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h + // uses Conv2dConfiguration. 
+ auto* conf_ptr = reinterpret_cast(configuration); + status = update_operator_arguments_from_configuration(out_args, *conf_ptr); + } + else if constexpr (Operator::NumSpatialDimensions == 3) { + CUTLASS_ASSERT(description_.kind == OperationKind::kConv3d); + auto* conf_ptr = reinterpret_cast(configuration); + status = update_operator_arguments_from_configuration(out_args, *conf_ptr); + } + else { + static_assert(::cutlass::detail::dependent_false, + "This class currently only supports 2-D and 3-D convolutions."); + } + + return status; + } + +public: + Status can_implement( + void const* configuration, + void const* arguments) const override { + Status status = Status::kInvalid; + + // gemm_operation_3x.hpp accesses "configuration" as + // GemmUniversalConfiguration (which lives in + // tools/library/include/cutlass/library/library.h) and + // "arguments" as GemmUniversalArguments (which lives in + // tools/library/include/cutlass/library/library.h). + // Those things don't apply to convolutions. + // Despite the existence of ConvUniversal, there's no + // corresponding "ConvUniversalConfiguration" or + // "ConvUniversalArguments." 
+ + CUTLASS_ASSERT(configuration != nullptr); + CUTLASS_ASSERT(arguments != nullptr); + + typename Operator::Arguments out_args{}; + status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration); + if (status != Status::kSuccess) { + CUTLASS_TRACE_HOST("*** can_implement: update_operator_arguments_from_configuration_2d_or_3d failed"); + return status; + } + + auto* in_args_ptr = reinterpret_cast(arguments); + status = update_operator_arguments_from_arguments(out_args, *in_args_ptr); + if (status != Status::kSuccess) { + CUTLASS_TRACE_HOST("*** can_implement: update_operator_arguments_from_arguments failed"); + return status; + } + + return Operator::can_implement(out_args); + } + + uint64_t get_host_workspace_size(void const* /* configuration */) const override { + return sizeof(Operator); + } + + uint64_t get_device_workspace_size( + void const* configuration, + void const* arguments = nullptr) const override + { + // This presumes that at least one of configuration or arguments is nonnull. + Status status = Status::kInvalid; + + // gemm_operation_3x.hpp has get_device_workspace_size return 0 on + // error. It's not clear that this is what we want -- perhaps we + // should return something like expected? -- but + // it's the only option that preserves the current interface. 
+ constexpr uint64_t error_indication = 0; + + typename Operator::Arguments out_args{}; + if (configuration != nullptr) { + status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration); + if (status != Status::kSuccess) { + return error_indication; + } + } + if (arguments != nullptr) { + auto* in_args_ptr = reinterpret_cast(arguments); + status = update_operator_arguments_from_arguments(out_args, *in_args_ptr); + if (status != Status::kSuccess) { + return error_indication; + } + } + + if (status == Status::kSuccess) { + return static_cast(Operator::get_workspace_size(out_args)); + } + else { + return error_indication; + } + } + + Status initialize( + void const* configuration, + void* host_workspace, + void* /* device_workspace */ = nullptr, + cudaStream_t stream = nullptr) const override + { + Status status = Status::kInvalid; + + if (configuration == nullptr) { + CUTLASS_TRACE_HOST("Input configuration is null."); + return Status::kInvalid; + } + + typename Operator::Arguments out_args{}; + status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration); + if (status != Status::kSuccess) { + // Any kind of failure invalidates the last successful configuration. + clear_last_successful_config(); + return status; + } + else { + set_last_successful_config(configuration); + } + + if (host_workspace == nullptr) { + CUTLASS_TRACE_HOST("host_workspace is null."); + return Status::kInvalid; + } + (void) new (host_workspace) Operator; + return status; + + // CUTLASS 2 convolutions call the Operator's initialize function + // here, like this. + // + //return op->initialize(args, device_workspace, stream); + // + // CUTLASS 3 convolutions (ConvUniversal), like CUTLASS 3 Gemms + // (GemmUniversal), lack an "initialize" member function. 
+ } + + Status run( + void const* arguments, + void* host_workspace, + void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const override + { + auto status = Status::kInvalid; + + // The Operator doesn't appear to save the last configuration (it + // doesn't have a way to do that, since it lacks an initialize() + // member function), so we have to use the stored configuration + // from the last successful initialize() call (if any). + typename Operator::Arguments out_args{}; + status = update_operator_arguments_from_stored_configuration(out_args); + if (status != Status::kSuccess) { + CUTLASS_TRACE_HOST("Updating from previous successful configuration failed."); + return status; + } + + if (arguments == nullptr) { + CUTLASS_TRACE_HOST("Input argument 'arguments' is null."); + return Status::kInvalid; + } + auto* in_args_ptr = reinterpret_cast(arguments); + status = update_operator_arguments_from_arguments(out_args, *in_args_ptr); + if (status != Status::kSuccess) { + return status; + } + + auto* op = reinterpret_cast(host_workspace); + return op->run(out_args, device_workspace, stream, nullptr, in_args_ptr->use_pdl); + } + +private: + ConvDescription description_; + // Result of initialize() calling + // update_operator_arguments_from_configuration() successfully. + // This is needed because run() doesn't take a configuration, just + // arguments, and the kernel doesn't appear to save the + // configuration from the last initialize() call. + // + // Unfortunately, this must be declared mutable, because it must be + // set in initialize(), and initialize() is inherited as const. + mutable std::variant< + std::monostate, + Conv2dConfiguration, + Conv3dConfiguration> last_successful_config_{std::monostate{}}; + + // Clear the last configuration resulting from a successful initialize() call. + // + // Unfortunately, this must be declared const, because initialize() is. 
+ void clear_last_successful_config() const { + last_successful_config_ = std::monostate{}; + } + + // Set the last configuration resulting from a successful initialize() call. + // + // Unfortunately, this must be declared const, because initialize() is. + void set_last_successful_config(void const* configuration) const { + CUTLASS_ASSERT(configuration != nullptr); + + if constexpr (Operator::NumSpatialDimensions == 2) { + CUTLASS_ASSERT(description_.kind == OperationKind::kConv2d); + auto* conf_ptr = reinterpret_cast(configuration); + last_successful_config_ = *conf_ptr; + } else if constexpr (Operator::NumSpatialDimensions == 3) { + CUTLASS_ASSERT(description_.kind == OperationKind::kConv3d); + auto* conf_ptr = reinterpret_cast(configuration); + last_successful_config_ = *conf_ptr; + } + else { + static_assert(::cutlass::detail::dependent_false, + "This class currently only supports 2-D and 3-D convolutions."); + } + } + + // Whether a configuration from a successful initialize() call exists. + bool last_successful_config_exists() const { + return not std::holds_alternative(last_successful_config_); + } + + // Visitor for update_operator_arguments_from_stored_configuration. + struct ConfigurationVisitor { + typename Operator::Arguments& out_args; + + Status operator() (std::monostate const&) const { + CUTLASS_TRACE_HOST("No successful previous configuration exists. " + "One cause is calling run() before a successful initialize() call."); + return Status::kInvalid; + } + Status operator() (Conv2dConfiguration const& conf2d) const { + return update_operator_arguments_from_configuration(out_args, conf2d); + } + Status operator() (Conv3dConfiguration const& conf3d) const { + return update_operator_arguments_from_configuration(out_args, conf3d); + } + }; + + // Like update_operator_arguments_from_configuration, but on the + // stored configuration from the last successful initialize() call, + // if any. 
If there was no last successful initialize() call, + // then return Status::kInvalid. + // + // Unfortunately, this must be declared const, because run() is. + Status update_operator_arguments_from_stored_configuration( + typename Operator::Arguments& out_args) const + { + return std::visit(ConfigurationVisitor{out_args}, last_successful_config_); + } + + template + struct UpdateFusionArgs { + static Status update_( + FusionArgs const&, + ConvArguments const&) + { + // For custom EVT, it is the user's responsibility to ensure + // that alpha and beta are updated appropriately. + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_( + FusionArgs& fusion_args, + ConvArguments const& arguments) + { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + static Status update_operator_arguments_from_configuration( + typename Operator::Arguments& out_args, + Conv2dConfiguration const& config) + { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("ConvOperator3x::" + "update_operator_arguments_from_configuration" + "(Conv2dConfiguration)\n"); +#endif + using detail::vector_to_array_strides; + + constexpr int num_spatial_dims = Operator::NumSpatialDimensions; + if constexpr (num_spatial_dims != 2) { + CUTLASS_TRACE_HOST("You can only use Conv2dConfiguration " + "with an Operator whose NumSpatialDimensions is exactly 2."); + return Status::kInvalid; + } + 
else { + // Convolutions split the metadata (in Conv2dConfiguration) from + // the data (ConvArguments, which only has pointers and a single + // enum value). Thus, this class will need both the + // configuration and the (user's input) arguments to set up the + // kernel's arguments. This function can fill in what the + // configuration has now, but the class will need the user's + // input arguments later. + if (config.split_k_mode != conv::SplitKMode::kSerial) { + CUTLASS_TRACE_HOST("CUTLASS 3 convolutions currently only support split_k_mode = kSerial."); + return Status::kInvalid; + } + // config.problem_size.split_k_slices is only meaningful if + // split_k_mode != kSerial. If this code later supports other + // split_k_mode values, then it will also need to read + // split_k_slices. + + const int N = config.problem_size.N; + const int H = config.problem_size.H; + const int W = config.problem_size.W; + const int C = config.problem_size.C; + const int K = config.problem_size.K; + const int R = config.problem_size.R; + const int S = config.problem_size.S; + const int pad_h = config.problem_size.pad_h; + const int pad_w = config.problem_size.pad_w; + const int traversal_stride_h = config.problem_size.stride_h; + const int traversal_stride_w = config.problem_size.stride_w; + const int dilation_h = config.problem_size.dilation_h; + const int dilation_w = config.problem_size.dilation_w; + + // CUTLASS 3's implicit GEMM convolution kernels currently only + // support cross correlation (passing over the activation and + // filter tensors in the same order). The convolution mode is + // future work. 
+ const auto mode = config.problem_size.mode; + if (mode != cutlass::conv::Mode::kCrossCorrelation) { + CUTLASS_TRACE_HOST("Convolution modes other than kCrossCorrelation " + "are not currently supported."); + return Status::kInvalid; + } + + constexpr int num_spatial_dims = Operator::NumSpatialDimensions; + constexpr size_t stride_size = size_t(num_spatial_dims) + 2u; + constexpr auto the_stride_size = std::integral_constant{}; + +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + std::cerr << " num_spatial_dims = " << num_spatial_dims << "\n" + << " stride_size = " << stride_size << "\n"; + auto print_stride = [] (auto const& stride, char const variable_name[]) { + std::cerr << " " << variable_name << ": ["; + for (size_t k = 0; k < stride.size(); ++k) { + std::cerr << stride[k]; + if (k + 1u < stride.size()) { + std::cerr << ", "; + } + } + std::cerr << "]\n"; + }; + print_stride(config.stride_a, "config.stride_a"); + print_stride(config.stride_b, "config.stride_b"); + print_stride(config.stride_c, "config.stride_c"); +#endif + + // Conv2dConfiguration stores the strides as std::vector, + // so the code needs to check the run-time vector lengths. 
+ if (config.stride_a.size() + 1u != stride_size) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) + std::ostringstream os; + os << "config.stride_a.size() + 1u = " + << (config.stride_a.size() + 1u) + << " != num_spatial_dims + 2u = " << stride_size; + CUTLASS_TRACE_HOST( os.str() ); +#endif + return Status::kInvalid; + } + if (config.stride_b.size() + 1u != stride_size) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) + std::ostringstream os; + os << "config.stride_b.size() + 1u = " + << (config.stride_b.size() + 1u) + << " != num_spatial_dims + 2u = " << stride_size; + CUTLASS_TRACE_HOST( os.str() ); +#endif + return Status::kInvalid; + } + if (config.stride_c.size() + 1u != stride_size) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) + std::ostringstream os; + os << "config.stride_c.size() + 1u = " + << (config.stride_c.size() + 1u) + << " != num_spatial_dims + 2u = " << stride_size; + CUTLASS_TRACE_HOST( os.str() ); +#endif + return Status::kInvalid; + } + + constexpr cutlass::conv::Operator conv_op = Operator::DispatchPolicy::ConvOp; + using problem_shape_type = + cutlass::conv::ConvProblemShape; + // cute::array; must convert to the kernel's native strides + using TensorStride = typename problem_shape_type::TensorStride; + + const TensorStride stride_A = vector_to_array_strides(config.stride_a, the_stride_size); + const TensorStride stride_B = vector_to_array_strides(config.stride_b, the_stride_size); + const TensorStride stride_C = vector_to_array_strides(config.stride_c, the_stride_size); + + // cutlass::library::Conv2dConfiguration has no member stride_d. + // The code below imitates the testbed, + // which just sets D's strides to C's strides. + + const int num_groups = config.problem_size.groups; + if (num_groups != 1) { + CUTLASS_TRACE_HOST("CUTLASS 3 kernels currently only support groups = 1."); + return Status::kInvalid; + } + // ConvProblemShape is how CUTLASS 3 kernels represent + // convolution problems. 
ConvProblemShape's constructors take + // shape_act, stride_act, shape_flt, and stride_flt, and set + // shape_A, stride_A, shape_B, stride_B, shape_C, and stride_C + // according to Fprop / Dgrad / Wgrad. + // + // This means that stride_act isn't always config.stride_A, + // depending on Fprop / Dgrad / Wgrad. The code here "undoes" + // the logic in Conv2dWorkspace::set_stride_vector so that we + // can recover the strides of the activation and filter tensors. + // It doesn't need to worry about the so-called "output" tensor + // (which might not be C), as ConvProblemShape's constructor + // figures out its shapes and strides. + using TensorExtent = typename problem_shape_type::TensorExtent; + TensorExtent shape_act{N, H, W, C}; + auto stride_act = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_A; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_C; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_B; + } + } (); + TensorExtent shape_flt{K, R, S, C}; + auto stride_flt = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." 
+ constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_B; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_B; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_C; + } + } (); + + problem_shape_type problem_shape( + /* mode = */ mode, + /* shape_act = */ shape_act, + /* stride_act = */ stride_act, + /* shape_flt = */ shape_flt, + /* stride_flt = */ stride_flt, + /* lower_padding = */ {pad_h, pad_w}, + /* upper_padding = */ {pad_h, pad_w}, + /* traversal_stride = */ {traversal_stride_h, traversal_stride_w}, + /* dilation = */ {dilation_h, dilation_w}, + num_groups); + out_args.problem_shape = problem_shape; + + // ConvProblemShape's constructor sets its shape_C member. +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("\n problem_shape.shape_C: "); + print(problem_shape.shape_C); + printf("\n problem_shape.stride_C: "); + print(problem_shape.stride_C); + printf("\n"); +#endif + // Initialization of C's and D's strides follows the CUTLASS 3 + // convolutions testbed (test/unit/conv/device_3x/testbed_conv.hpp). 
+ { + using StrideC = typename Operator::ConvKernel::StrideC; + using StrideD = typename Operator::ConvKernel::StrideD; + auto stride_C = StrideC{}; + auto stride_D = StrideD{}; + + if constexpr (conv_op == cutlass::conv::Operator::kWgrad) { + stride_C = cutlass::make_cute_packed_stride( + StrideC{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); + stride_D = cutlass::make_cute_packed_stride( + StrideD{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + std::cerr << " Wgrad: stride_C: " << stride_C << "\n"; +#endif + } + else { + cute::for_each(cute::make_seq(StrideC{})>{}, [&](auto i) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + const auto stride_C_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_C): " + << stride_C_i << "\n"; +#endif + cute::get<0, i>(stride_C) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + }); + cute::for_each(cute::make_seq(StrideD{})>{}, [&](auto i) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + const auto stride_D_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_D): " + << stride_D_i << "\n"; +#endif + cute::get<0, i>(stride_D) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + }); + } + out_args.epilogue.dC = stride_C; + out_args.epilogue.dD = stride_D; + } + return Status::kSuccess; + } + } + + static Status update_operator_arguments_from_configuration( + typename Operator::Arguments& out_args, + Conv3dConfiguration const& config) + { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("ConvOperator3x::" + "update_operator_arguments_from_configuration" + "(Conv3dConfiguration)\n"); +#endif + using detail::coord_to_array_strides; + + constexpr int num_spatial_dims = 
Operator::NumSpatialDimensions; + if constexpr (num_spatial_dims != 3) { + CUTLASS_TRACE_HOST("You can only use Conv3dConfiguration " + "with an Operator whose NumSpatialDimensions is exactly 3."); + return Status::kInvalid; + } + else { + // Convolutions split the metadata (in Conv3dConfiguration) from + // the data (ConvArguments, which only has pointers and a single + // enum value). Thus, this class will need both the + // configuration and the (user's input) arguments to set up the + // kernel's arguments. This function can fill in what the + // configuration has now, but the class will need the user's + // input arguments later. + if (config.split_k_mode != conv::SplitKMode::kSerial) { + CUTLASS_TRACE_HOST("CUTLASS 3 convolutions currently only support split_k_mode = kSerial."); + return Status::kInvalid; + } + // config.problem_size.split_k_slices is only meaningful if + // split_k_mode != kSerial. If this code later supports other + // split_k_mode values, then it will also need to read + // split_k_slices. 
+ + const int N = config.problem_size.N; + const int D = config.problem_size.D; + const int H = config.problem_size.H; + const int W = config.problem_size.W; + const int C = config.problem_size.C; + const int K = config.problem_size.K; + const int T = config.problem_size.T; + const int R = config.problem_size.R; + const int S = config.problem_size.S; + const int pad_d = config.problem_size.pad_d; + const int pad_h = config.problem_size.pad_h; + const int pad_w = config.problem_size.pad_w; + const int traversal_stride_d = config.problem_size.stride_d; + const int traversal_stride_h = config.problem_size.stride_h; + const int traversal_stride_w = config.problem_size.stride_w; + const int dilation_d = config.problem_size.dilation_d; + const int dilation_h = config.problem_size.dilation_h; + const int dilation_w = config.problem_size.dilation_w; + + // CUTLASS 3's implicit GEMM convolution kernels currently only + // support cross correlation (passing over the activation and + // filter tensors in the same order). The convolution mode is + // future work. 
+ const auto mode = config.problem_size.mode; + if (mode != cutlass::conv::Mode::kCrossCorrelation) { + CUTLASS_TRACE_HOST("Convolution modes other than kCrossCorrelation " + "are not currently supported."); + return Status::kInvalid; + } + + using Stride = cutlass::layout::TensorNDHWC::Stride; + static_assert(std::is_same_v>); + + const cutlass::library::ConvKind conv_kind = [] () { + constexpr cutlass::conv::Operator op = Operator::DispatchPolicy::ConvOp; + if constexpr (op == cutlass::conv::Operator::kFprop) { + return library::ConvKind::kFprop; + } + else if constexpr (op == cutlass::conv::Operator::kDgrad) { + return library::ConvKind::kDgrad; + } + else /* if constexpr (op == cutlass::conv::Operator::kWgrad) */ { + return library::ConvKind::kWgrad; + } + } (); + const Stride input_stride_a = config.layout_a(conv_kind).stride(); + const Stride input_stride_b = config.layout_b(conv_kind).stride(); + const Stride input_stride_c = config.layout_c(conv_kind).stride(); + +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + constexpr size_t stride_size = size_t(num_spatial_dims) + 2u; + std::cerr << " num_spatial_dims = " << num_spatial_dims << "\n" + << " stride_size = " << stride_size << "\n"; + auto print_stride = [] (Stride const& stride, char const variable_name[]) { + std::cerr << " " << variable_name << ": ["; + for (size_t k = 0; k < Stride::kRank; ++k) { + std::cerr << stride[static_cast(k)]; + if (k + 1u < Stride::kRank) { + std::cerr << ", "; + } + } + std::cerr << "]\n"; + }; + print_stride(input_stride_a, "input_stride_a"); + print_stride(input_stride_b, "input_stride_b"); + print_stride(input_stride_c, "input_stride_c"); +#endif + // Conv3dConfiguration stores the strides as Coord (with + // compile-time size), so there's no need to check sizes here + // (unlike Conv2dConfiguration, which stores strides as + // std::vector). 
+ + constexpr cutlass::conv::Operator conv_op = Operator::DispatchPolicy::ConvOp; + using problem_shape_type = + cutlass::conv::ConvProblemShape; + // cute::array; must convert to the kernel's native strides + using TensorStride = typename problem_shape_type::TensorStride; + + const TensorStride stride_A = coord_to_array_strides(input_stride_a); + const TensorStride stride_B = coord_to_array_strides(input_stride_b); + const TensorStride stride_C = coord_to_array_strides(input_stride_c); + + const int num_groups = config.problem_size.groups; + if (num_groups != 1) { + CUTLASS_TRACE_HOST("CUTLASS 3 kernels currently only support groups = 1."); + return Status::kInvalid; + } + // ConvProblemShape is how CUTLASS 3 kernels represent + // convolution problems. ConvProblemShape's constructors take + // shape_act, stride_act, shape_flt, and stride_flt, and set + // shape_A, stride_A, shape_B, stride_B, shape_C, and stride_C + // according to Fprop / Dgrad / Wgrad. + // + // Conv3dConfiguration differs a bit from Conv2dConfiguration, + // but the idea is the same: the "input_stride_a" from config + // depends on conv_kind (Fprop, Dgrad, or Wgrad), so stride_act + // isn't always input_stride_a. Analogously, stride_flt isn't + // always input_stride_b. The code here "undoes" the logic in + // config.layout_a(conv_kind) and config.layout_b(conv_kind) + // (analogous to Conv2dWorkspace::set_stride_vector) so that we + // can recover the strides of the activation and filter tensors. + // It doesn't need to worry about the so-called "output" tensor + // (which might not be C), as ConvProblemShape's constructor + // figures out its shapes and strides. + using TensorExtent = typename problem_shape_type::TensorExtent; + TensorExtent shape_act{N, D, H, W, C}; + auto stride_act = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." 
+ constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_A; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_C; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_B; + } + } (); + TensorExtent shape_flt{K, T, R, S, C}; + auto stride_flt = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_B; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_B; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_C; + } + } (); + + problem_shape_type problem_shape( + /* mode = */ mode, + /* shape_act = */ shape_act, + /* stride_act = */ stride_act, + /* shape_flt = */ shape_flt, + /* stride_flt = */ stride_flt, + /* lower_padding = */ {pad_d, pad_h, pad_w}, + /* upper_padding = */ {pad_d, pad_h, pad_w}, + /* traversal_stride = */ {traversal_stride_d, traversal_stride_h, traversal_stride_w}, + /* dilation = */ {dilation_d, dilation_h, dilation_w}, + num_groups); + out_args.problem_shape = problem_shape; + + // ConvProblemShape's constructor sets its shape_C member. +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("\n problem_shape.shape_C: "); + print(problem_shape.shape_C); + printf("\n problem_shape.stride_C: "); + print(problem_shape.stride_C); + printf("\n"); +#endif + // Initialization of C's and D's strides follows the CUTLASS 3 + // convolutions testbed (test/unit/conv/device_3x/testbed_conv.hpp). 
+ { + using StrideC = typename Operator::ConvKernel::StrideC; + using StrideD = typename Operator::ConvKernel::StrideD; + auto stride_C = StrideC{}; + auto stride_D = StrideD{}; + + if constexpr (conv_op == cutlass::conv::Operator::kWgrad) { + stride_C = cutlass::make_cute_packed_stride( + StrideC{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); + stride_D = cutlass::make_cute_packed_stride( + StrideD{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + std::cerr << " Wgrad: stride_C: " << stride_C << "\n"; +#endif + } + else { + cute::for_each(cute::make_seq(StrideC{})>{}, [&](auto i) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + const auto stride_C_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_C): " + << stride_C_i << "\n"; +#endif + cute::get<0, i>(stride_C) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + }); + cute::for_each(cute::make_seq(StrideD{})>{}, [&](auto i) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + const auto stride_D_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_D): " + << stride_D_i << "\n"; +#endif + cute::get<0, i>(stride_D) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + }); + } + out_args.epilogue.dC = stride_C; + out_args.epilogue.dD = stride_D; + } + return Status::kSuccess; + } + } + + Status update_operator_arguments_from_arguments( + typename Operator::Arguments& out_args, + ConvArguments const& in_args) const + { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("ConvOperation3x::update_operator_arguments_from_arguments\n"); +#endif + auto status = UpdateFusionArgs::update_( + out_args.epilogue.thread, in_args); + if (status != Status::kSuccess) { + return status; + } 
+ + out_args.mainloop.ptr_A = reinterpret_cast(in_args.A); + out_args.mainloop.ptr_B = reinterpret_cast(in_args.B); + + out_args.epilogue.ptr_C = reinterpret_cast(in_args.C); + out_args.epilogue.ptr_D = reinterpret_cast(in_args.D); + + return Status::kSuccess; + } +}; + +} // namespace cutlass::library diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/gemm_operation.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/gemm_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..880cb4bf34b1f3d946e1dc86b80806309bb2b3c1 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/gemm_operation.h @@ -0,0 +1,1408 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + +#pragma once +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/device/gemm_complex.h" +#include "cutlass/gemm/device/gemm_batched.h" +#include "cutlass/gemm/device/gemm_array.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + // assuming all tensors use same type for 
StrideIndex + using StrideIndex = typename Operator::LayoutA::Index; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + GemmDescription description_; + +public: + + /// Constructor + GemmOperationBase(char const *name = "unknown_gemm") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kGemm; + description_.gemm_kind = GemmKind::kGemm; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::GemmKernel::WarpCount::kM, + Operator::GemmKernel::WarpCount::kN, + Operator::GemmKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + 
description_.transform_A = ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kGemm; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmConfiguration const *configuration) { + + operator_args.problem_size = configuration->problem_size; + + operator_args.ref_A = {nullptr, configuration->lda}; + operator_args.ref_B = {nullptr, configuration->ldb}; + operator_args.ref_C = {nullptr, configuration->ldc}; + operator_args.ref_D = {nullptr, configuration->ldd}; + + operator_args.split_k_slices = configuration->split_k_slices; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmArguments const *arguments) { + + if (arguments->pointer_mode == 
ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + operator_args.ref_A.reset(static_cast(arguments->A)); + operator_args.ref_B.reset(static_cast(arguments->B)); + operator_args.ref_C.reset(static_cast(arguments->C)); + operator_args.ref_D.reset(static_cast(arguments->D)); + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual 
Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + return op->initialize(args, device_workspace, stream); + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + return op->run(stream); + } + + void print_operator_args(OperatorArguments &operator_args) const { +#if 0 + std::cout << "GemmOperation::OperatorArguments" << std::endl; + std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl; + std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl; + std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl; + std::cout << " beta: " << operator_args.epilogue.beta << std::endl; + std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl; + std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl; + std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl; + std::cout << " ref_B.data(): " << operator_args.ref_B.data() << std::endl; + std::cout << " ref_B.stride: " << operator_args.ref_B.stride(0) << std::endl; + std::cout << " ref_C.data(): " << operator_args.ref_C.data() << std::endl; + std::cout << " ref_C.stride: " << 
operator_args.ref_C.stride(0) << std::endl; +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmSparseOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementE = typename Operator::ElementE; + using LayoutE = typename Operator::LayoutE; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmSparseOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.kind = OperationKind::kSparseGemm; + this->description_.gemm_kind = GemmKind::kSparse; + this->description_.E = make_TensorDescription(Operator::kAlignmentE); + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + SparseGemmConfiguration const *configuration) { + + operator_args.problem_size = configuration->problem_size; + operator_args.ref_A = {nullptr, configuration->lda}; + operator_args.ref_B = {nullptr, configuration->ldb}; + operator_args.ref_C = {nullptr, configuration->ldc}; + operator_args.ref_D = {nullptr, configuration->ldd}; + operator_args.ref_E = {nullptr, configuration->lde}; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + SparseGemmArguments const *arguments) { + + 
if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.ref_A.reset(static_cast(arguments->A)); + operator_args.ref_B.reset(static_cast(arguments->B)); + operator_args.ref_C.reset(static_cast(arguments->C)); + operator_args.ref_D.reset(static_cast(arguments->D)); + operator_args.ref_E.reset(static_cast(arguments->E)); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + SparseGemmConfiguration const *configuration = + static_cast(configuration_ptr); + + SparseGemmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + 
} + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + return op->initialize(args, device_workspace, stream); + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + return op->run(stream); + } + + void print_operator_args(OperatorArguments &operator_args) const { +#if 0 + std::cout << "GemmOperation::OperatorArguments" << std::endl; + std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl; + std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl; + std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl; + std::cout << " beta: " << operator_args.epilogue.beta << std::endl; + std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl; + std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl; + std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl; + std::cout << " ref_B.data(): " << operator_args.ref_B.data() << std::endl; + std::cout << " ref_B.stride: " << operator_args.ref_B.stride(0) << std::endl; + std::cout << " ref_C.data(): " << 
operator_args.ref_C.data() << std::endl; + std::cout << " ref_C.stride: " << operator_args.ref_C.stride(0) << std::endl; +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmUniversalOperation(char const *name = "unknown_gemm"): + GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmUniversalConfiguration const *configuration) { + + operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = (configuration->lda); + operator_args.ldb = (configuration->ldb); + operator_args.ldc = (configuration->ldc); + operator_args.ldd = (configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmUniversalArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + 
*static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_B = arguments->B; + operator_args.ptr_C = arguments->C; + operator_args.ptr_D = arguments->D; + + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_B = arguments->batch_stride_B; + operator_args.batch_stride_C = arguments->batch_stride_C; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmUniversalArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) 
{ + return 0; + } + + status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmPlanarComplexOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename 
Operator::Arguments; + +public: + + /// Constructor + GemmPlanarComplexOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kPlanarComplex; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexConfiguration const *configuration) { + + operator_args.mode = cutlass::gemm::GemmUniversalMode::kBatched; + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + + operator_args.lda_real = configuration->lda_real; + operator_args.lda_imag = configuration->lda_imag; + operator_args.ldb_real = configuration->ldb_real; + operator_args.ldb_imag = configuration->ldb_imag; + operator_args.ldc_real = configuration->ldc_real; + operator_args.ldc_imag = configuration->ldc_imag; + operator_args.ldd_real = configuration->ldd_real; + operator_args.ldd_imag = configuration->ldd_imag; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast const *>(arguments->alpha), + *static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast const *>(arguments->alpha), + static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A_real = arguments->A_real; + operator_args.ptr_A_imag = arguments->A_imag; + operator_args.ptr_B_real = arguments->B_real; + 
operator_args.ptr_B_imag = arguments->B_imag; + operator_args.ptr_C_real = arguments->C_real; + operator_args.ptr_C_imag = arguments->C_imag; + operator_args.ptr_D_real = arguments->D_real; + operator_args.ptr_D_imag = arguments->D_imag; + + operator_args.batch_stride_A = arguments->batch_stride_A_real; + operator_args.batch_stride_A_imag = arguments->batch_stride_A_imag; + operator_args.batch_stride_B = arguments->batch_stride_B_real; + operator_args.batch_stride_B_imag = arguments->batch_stride_B_imag; + operator_args.batch_stride_C = arguments->batch_stride_C_real; + operator_args.batch_stride_C_imag = arguments->batch_stride_C_imag; + operator_args.batch_stride_D = arguments->batch_stride_D_real; + operator_args.batch_stride_D_imag = arguments->batch_stride_D_imag; + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmPlanarComplexConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmPlanarComplexArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + 
return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmPlanarComplexArrayOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmPlanarComplexArrayOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.gemm_kind = 
GemmKind::kPlanarComplexArray; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexArrayConfiguration const *configuration) { + + operator_args.mode = cutlass::gemm::GemmUniversalMode::kArray; + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda_real = configuration->lda_real; + operator_args.lda_imag = configuration->lda_imag; + operator_args.ldb_real = configuration->ldb_real; + operator_args.ldb_imag = configuration->ldb_imag; + operator_args.ldc_real = configuration->ldc_real; + operator_args.ldc_imag = configuration->ldc_imag; + operator_args.ldd_real = configuration->ldd_real; + operator_args.ldd_imag = configuration->ldd_imag; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexArrayArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast const *>(arguments->alpha), + *static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast const *>(arguments->alpha), + static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A_real = arguments->A_real; + operator_args.ptr_A_imag = arguments->A_imag; + operator_args.ptr_B_real = arguments->B_real; + operator_args.ptr_B_imag = arguments->B_imag; + operator_args.ptr_C_real = arguments->C_real; + operator_args.ptr_C_imag = arguments->C_imag; + operator_args.ptr_D_real = 
arguments->D_real; + operator_args.ptr_D_imag = arguments->D_imag; + + operator_args.ptr_M = arguments->M; + operator_args.ptr_N = arguments->N; + operator_args.ptr_K = arguments->K; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmPlanarComplexArrayConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmPlanarComplexArrayArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// 
Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmGroupedOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmGroupedOperation(char const *name = "unknown_gemm"): + GemmOperationBase(name) { + + this->description_.kind = OperationKind::kGroupedGemm; + this->description_.provider = Provider::kCUTLASS; + this->threadblock_count = Operator::sufficient(); + + this->description_.gemm = GemmOperationBase::description_; + this->description_.gemm.gemm_kind = GemmKind::kGrouped; + this->description_.tile_description = this->description_.gemm.tile_description; + } + + /// Returns the description of the GroupedGEMM operation + virtual OperationDescription const & description() const override final { + return description_; + } + + +private: + int threadblock_count; + 
GroupedGemmDescription description_; + +protected: + + /// Constructs the arguments structure given the configuration and arguments + Status construct_arguments_( + OperatorArguments &op_args, + GemmGroupedConfiguration const *config) const { + + op_args.problem_count = config->problem_count; + op_args.threadblock_count = threadblock_count; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + Status update_arguments_( + OperatorArguments &op_args, + GemmGroupedArguments const *arguments) const { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + + op_args.output_op = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice) { + + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + + op_args.output_op = params; + } + else { + return Status::kErrorInvalidProblem; + } + + op_args.threadblock_count = threadblock_count; + op_args.problem_count = arguments->problem_count; + op_args.problem_sizes = arguments->problem_sizes; + + op_args.ptr_A = static_cast(arguments->ptr_A); + op_args.ptr_B = static_cast(arguments->ptr_B); + op_args.ptr_C = static_cast(arguments->ptr_C); + op_args.ptr_D = static_cast(arguments->ptr_D); + + op_args.lda = arguments->lda; + op_args.ldb = arguments->ldb; + op_args.ldc = arguments->ldc; + op_args.ldd = arguments->ldd; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmGroupedConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmGroupedArguments const *arguments = + static_cast(arguments_ptr); + + 
OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } 
+}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/gemm_operation_3x.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2c1d17943f11fe8126b3070c3fcead5598e2d207 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/gemm_operation_3x.hpp @@ -0,0 +1,714 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/array.h" +#include "cutlass/array_subbyte.h" +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cute/tensor.hpp" +#include + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmOperation3xBase : public Operation { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename 
Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + // assuming all tensors use same type for StrideIndex + using StrideIndex = typename Operator::LayoutA::Index; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + +protected: + GemmDescription description_; + +public: + + /// Constructor + GemmOperation3xBase(char const *name = "unknown_gemm", GemmKind gemm_kind_ = GemmKind::kGemm) { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kGemm; + description_.gemm_kind = gemm_kind_; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + description_.tile_description.cluster_shape = make_Coord( + Operator::ClusterShape::kM, + Operator::ClusterShape::kN, + Operator::ClusterShape::kK); + } + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::WarpCount::kM, + Operator::WarpCount::kN, + Operator::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = 
make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentD); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + /// Returns the description of the GEMM operation + GemmDescription const& get_gemm_description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversal3xOperation : public GemmOperation3xBase { +public: + + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA 
&& !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + +public: + + /// Constructor + GemmUniversal3xOperation(char const *name = "unknown_gemm"): + GemmOperation3xBase(name, GemmKind::kUniversal) { + if constexpr (Operator::ArchTag::kMinComputeCapability == 90) { + dim3 cluster_dims( + cute::size<0>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<1>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<2>(typename Operator::GemmKernel::ClusterShape{})); + uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock; + void const* kernel_ptr = (void*)(device_kernel); + max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters( + cluster_dims, + threads_per_block, + kernel_ptr); + } + } + +private: + int max_active_clusters{}; + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmUniversalArguments const &arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, GemmUniversalArguments const &arguments) { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + 
fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + template class Policy, int Stages, class ClusterShape, class KernelSchedule> + static constexpr bool is_sm90_mixed_dtype_mainloop_(Policy policy) { + return (cute::is_same_v, + cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput>); + } + + template + static constexpr bool is_sm90_mixed_dtype_mainloop_(DispatchPolicy) { + return false; + } + + template < + typename ElementWide, + typename ElementNarrow, + typename ElementScaleMainloop, + class ActualStrideAB, + Sm90MixedInputWiderOperand wider_operand, + bool is_n4w8, + typename ElementScale, + typename ElementZero, + class Layout_SZ> + static void dequantize_encode_( + OperatorArguments &operator_args, + GemmUniversalArguments const *arguments, + cudaStream_t stream, + const int &problem_mn, + const int &problem_k, + const int &options_l, + const int &options_g, + ElementScale *ptr_S, + ElementZero *ptr_Z, + const size_t &SZ_size, + Layout_SZ layout_SZ + ) { + + auto shape_AB = cute::make_shape(problem_mn, problem_k, options_l); + auto stride_AB = cutlass::make_cute_packed_stride(ActualStrideAB{}, shape_AB); + auto layout_AB = cute::make_layout(shape_AB, stride_AB); + auto *ptr_dequantized_AB = static_cast(arguments->dequantized_AB); + const ElementNarrow *ptr_AB = nullptr; + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + ptr_AB = static_cast(arguments->B); + } + else { + ptr_AB = static_cast(arguments->A); + } + dequantize(ptr_dequantized_AB, ptr_AB, layout_AB, 
ptr_S, ptr_Z, layout_SZ, options_g, stream); + if constexpr(is_n4w8) { + size_t AB_size = cute::size(layout_AB); + cutlass::int4b_t *encoded_AB = static_cast(arguments->encoded_AB); + unified_encode_int4b(ptr_AB, encoded_AB, AB_size); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + operator_args.mainloop.ptr_B = static_cast(encoded_AB); + } + else { + operator_args.mainloop.ptr_A = static_cast(encoded_AB); + } + ElementScaleMainloop *ptr_packed_Scale = static_cast(arguments->packed_Scale); + pack_scale_fp8(ptr_S, ptr_packed_Scale, SZ_size); + } + } + + template < + typename ElementAB, + class ActualStrideAB, + class LayoutAB_Reordered, + class LayoutAtomQuant, + Sm90MixedInputWiderOperand wider_operand> + static void handle_shuffle_tensor_( + OperatorArguments &operator_args, + GemmUniversalArguments const *arguments, + const int &problem_mn, + const int &problem_k, + const int &options_l) { + + auto shape_AB = cute::make_shape(problem_mn, problem_k, options_l); + auto stride_AB = cutlass::make_cute_packed_stride(ActualStrideAB{}, shape_AB); + auto layout_AB = cute::make_layout(shape_AB, stride_AB); + LayoutAB_Reordered layout_AB_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_AB); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + operator_args.mainloop.dB = layout_AB_reordered; + } + else { + operator_args.mainloop.dA = layout_AB_reordered; + } + if (arguments->generate_dequantized_AB) { + size_t AB_size = cute::size(layout_AB); + ElementAB *AB_reordered = cutlass::device_memory::allocate(AB_size); + const ElementAB *AB_src = nullptr; + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + AB_src = static_cast(operator_args.mainloop.ptr_B); + } + else { + AB_src = static_cast(operator_args.mainloop.ptr_A); + } + reorder_tensor(AB_src, layout_AB, AB_reordered, layout_AB_reordered); + ElementAB *AB_dst = static_cast(arguments->encoded_AB); + cutlass::device_memory::copy_device_to_device(AB_dst, AB_reordered, 
AB_size); + cutlass::device_memory::free(AB_reordered); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + operator_args.mainloop.ptr_B = AB_dst; + } + else { + operator_args.mainloop.ptr_A = AB_dst; + } + } + } + + /// Constructs the arguments structure given the configuration and arguments + Status update_arguments_( + OperatorArguments& operator_args, + GemmUniversalArguments const* arguments, + cudaStream_t stream = nullptr) const { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + + // TODO: type erase Arguments structure in 3.0 GEMM + operator_args.problem_shape = cute::make_shape( + arguments->problem_size.m(), + arguments->problem_size.n(), + arguments->problem_size.k(), + arguments->batch_count); + + // update arguments + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + + std::unordered_map mapping = { + {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3}, + {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2}, + {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2}, + {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1} + }; + + auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a); + auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b); + + if (iter_runtime_a != mapping.end()) { + operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second; + } else { + assert("invalid runtime argument for datatype A!"); + } + + if (iter_runtime_b != mapping.end()) { + operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second; + } else { + assert("invalid runtime 
argument for datatype B!"); + } + + } + else { + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + } + operator_args.epilogue.ptr_C = static_cast(arguments->C); + operator_args.epilogue.ptr_D = static_cast(arguments->D); + + // Stride{A,B} is a Layout if and only if: + // (1) This is a mixed dtype kernel, and + // (2) This mixed dtype kernel is using shuffling, and + // (3) sizeof(narrow_type) == 4 or 8 bits, and + // (4) sizeof(wide_type) == 16 bits. + // If A/B has the narrow data type, Stride{A/B} will be a Layout + constexpr bool is_StrideA_Layout = cute::is_layout::value; + constexpr bool is_StrideB_Layout = cute::is_layout::value; + static_assert(!(is_StrideA_Layout && is_StrideB_Layout), "Incorrect kernel configuration: StrideA and StrideB are both cute::Layout"); + if constexpr(!is_StrideA_Layout) { + operator_args.mainloop.dA = cute::make_int_tuple_from( + arguments->lda, arguments->batch_stride_A); + } + if constexpr(!is_StrideB_Layout) { + operator_args.mainloop.dB = cute::make_int_tuple_from( + arguments->ldb, arguments->batch_stride_B); + } + operator_args.epilogue.dC = cute::make_int_tuple_from( + arguments->ldc, arguments->batch_stride_C); + operator_args.epilogue.dD = operator_args.epilogue.dC; + + using MainloopPolicy = typename CollectiveMainloop::DispatchPolicy; + if constexpr(is_sm90_mixed_dtype_mainloop_(MainloopPolicy{})) { + const int problem_m = arguments->problem_size.m(); + const int problem_n = arguments->problem_size.n(); + const int problem_k = arguments->problem_size.k(); + const int options_l = arguments->batch_count; + + constexpr Sm90MixedInputWiderOperand wider_operand = + (cutlass::sizeof_bits::value > cutlass::sizeof_bits::value) ? 
+ Sm90MixedInputWiderOperand::A : Sm90MixedInputWiderOperand::B; + using ElementWide = std::conditional_t; + using ElementNarrow = std::conditional_t; + + constexpr bool has_scale = !std::is_same_v; + constexpr bool has_zero = !std::is_same_v; + + const int options_g = problem_k; + const int scale_k = (problem_k + options_g - 1) / options_g; + + constexpr bool is_A4B8 = ( + cutlass::is_same_v && + (cutlass::is_same_v || + cutlass::is_same_v)); + constexpr bool is_A8B4 = ( + cutlass::is_same_v && + (cutlass::is_same_v || + cutlass::is_same_v)); + constexpr bool is_int4_x_fp8 = is_A4B8 || is_A8B4; + + // If this is a convert-only kernel, we still need to generate dequantized A or B for verification, + // and in this case ElementScale is the same as ElementWide + // In int4 * fp8, ElementScale is a cutlass::Array, need to take out it's real element + using DummyElementScaleMainloop = std::conditional_t< + is_int4_x_fp8, + typename cutlass::Array, + ElementWide + >; + using ElementScaleMainloop = std::conditional_t< + has_scale, + typename CollectiveMainloop::ElementScale, + DummyElementScaleMainloop + >; + using ElementScale = std::conditional_t< + has_scale, + typename UnderlyingElement::type, + ElementWide + >; + using StrideScale = typename CollectiveMainloop::StrideScale; + // In ScaleOnly mode, we have allocated the same size of memory for arguments->Z and arguments->S + using ElementZero = std::conditional_t< + has_zero, + typename CollectiveMainloop::ElementZero, + ElementScale + >; + const int SZ_1st_dim = (wider_operand == Sm90MixedInputWiderOperand::A) ? problem_n : problem_m; + const size_t SZ_size = static_cast(SZ_1st_dim * scale_k * options_l); + auto shape_SZ = cute::make_shape(SZ_1st_dim, scale_k, options_l); + ElementScale *ptr_S = static_cast(arguments->Scale); + ElementZero *ptr_Z = static_cast(arguments->Zero); + + // 1. 
If arguments is initialized in profiler, S and Z needs to be allocated and filled + if (arguments->generate_scale_and_zero) { + float scale_min = 1.0f, scale_max = 1.0f; + if constexpr(has_scale) { + const float elt_max_f = float(cutlass::platform::numeric_limits::max()); + // Need to fix max_dequant_val and min_dequant_val? + const float max_dequant_val = elt_max_f * 0.25f; + const float min_dequant_val = 0.5f; + scale_max = max_dequant_val / elt_max_f; + scale_min = min_dequant_val / elt_max_f; + } + uint64_t seed = 2023; + cutlass::reference::device::BlockFillRandomUniform( + ptr_S, SZ_size, seed, ElementScale(scale_max), ElementScale(scale_min)); + + // In ScaleOnly mode, set Z as zero for generating dequantized A or B + const float zero_max = has_zero ? 2.0f : 0.0f; + const float zero_min = has_zero ? -2.0f : 0.0f; + cutlass::reference::device::BlockFillRandomUniform( + ptr_Z, SZ_size, seed, ElementZero(zero_max), ElementZero(zero_min)); + } // End of "if (arguments->generate_scale_and_zero)" + + // 2. 
Generate the dequantized A or B for verification + if (arguments->generate_dequantized_AB) { + StrideScale stride_SZ = cutlass::make_cute_packed_stride(StrideScale{}, shape_SZ); + auto layout_SZ = cute::make_layout(shape_SZ, stride_SZ); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + if constexpr(is_StrideB_Layout) { + // The generator only generates row-major A and col-major B at the moment + // Need a way to read out the actual layout of B later + using ActualLayoutB = cutlass::layout::ColumnMajor; + using ActualStrideB = cutlass::detail::TagToStrideB_t; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + else { + using ActualStrideB = typename CollectiveMainloop::StrideB; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + } + else { + if constexpr(is_StrideA_Layout) { + // The generator only generates row-major A and col-major B at the moment + // Need a way to read out the actual layout of A later + using ActualLayoutA = cutlass::layout::RowMajor; + using ActualStrideA = cutlass::detail::TagToStrideA_t; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + else { + using ActualStrideA = typename CollectiveMainloop::StrideA; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + } // End of "if constexpr(wider_operand == Sm90MixedInputWiderOperand::A)" + } // End of "if (arguments->generate_dequantized_AB)" + + // 3. 
Put Scale and Zero in mainloop + if constexpr(has_scale) { + if constexpr(is_int4_x_fp8) { + operator_args.mainloop.ptr_S = static_cast(arguments->packed_Scale); + } + else { + operator_args.mainloop.ptr_S = static_cast(arguments->Scale); + } + operator_args.mainloop.dS = cutlass::make_cute_packed_stride(StrideScale{}, shape_SZ); + operator_args.mainloop.group_size = options_g; + if constexpr(has_zero) { + operator_args.mainloop.ptr_Z = static_cast(arguments->Zero); + } + } // End of "if constexpr(has_scale)" + + // Handle the shuffling + using ValueShuffle = std::conditional_t< + cutlass::sizeof_bits::value == 4, + cute::Layout, cute::Stride>, + cute::Layout, cute::Stride> + >; + constexpr int NumShuffleAtoms = 1; + using MmaAtomShape = cute::Layout>>; + using LayoutAtomQuant = decltype(compute_memory_reordering_atom()); + // The generator only generates row-major A and col-major B at the moment + // Need a way to read out the actual layout and stride of A/B later + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A && is_StrideB_Layout) { + using ActualLayoutB = cutlass::layout::ColumnMajor; + using ActualStrideB = cutlass::detail::TagToStrideB_t; + using LayoutB_Reordered = typename CollectiveMainloop::StrideB; + handle_shuffle_tensor_( + operator_args, arguments, problem_n, problem_k, options_l); + } + if constexpr(wider_operand == Sm90MixedInputWiderOperand::B && is_StrideA_Layout) { + using ActualLayoutA = cutlass::layout::RowMajor; + using ActualStrideA = cutlass::detail::TagToStrideA_t; + using LayoutA_Reordered = typename CollectiveMainloop::StrideA; + handle_shuffle_tensor_( + operator_args, arguments, problem_m, problem_k, options_l); + } + } // End of "if constexpr(is_sm90_mixed_dtype_mainloop_(MainloopPolicy{}))" + + /* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (Operator::ArchTag::kMinComputeCapability == 90) { + 
operator_args.hw_info.max_active_clusters = max_active_clusters; + } + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments->raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = arguments->split_k_slices; + } + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + return status; + } + +public: + + /// Returns success if the operation can proceed + Status can_implement( + [[maybe_unused]] void const *configuration_ptr, void const *arguments_ptr) const override { + GemmUniversalArguments const *arguments = + static_cast(arguments_ptr); + OperatorArguments args; + + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + Status can_impl = Operator::can_implement(args); + + //return Operator::can_implement(args); + return can_impl; + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const *configuration) const override { + return sizeof(Operator); + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, 
static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + Operator *op = new (host_workspace) Operator; + return Status::kSuccess; + } + + /// Runs the kernel + Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments args; + Status status = update_arguments_(args, static_cast(arguments_ptr), stream); + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(args, device_workspace, stream, nullptr, + static_cast(arguments_ptr)->use_pdl); + return status; + } +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/grouped_gemm_operation_3x.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/grouped_gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..91f618d4fab74a6d43e2d82c572d215d5bea5a1c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/grouped_gemm_operation_3x.hpp @@ -0,0 +1,873 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all grouped GEMM operations in CUTLASS Library. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "gemm_operation_3x.hpp" +#include "library_internal.h" + +namespace cutlass::library { + +template +class GroupedGemmOperation3xBase : public GemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + GroupedGemmOperation3xBase(char const* name = "unknown_gemm") + : GemmOperation3xBase(name, GemmKind::kGrouped) { + this->description_.kind = OperationKind::kGroupedGemm; + this->description_.name = name; + this->description_.provider = Provider::kCUTLASS; + + 
this->description_.gemm = GemmOperation3xBase::description_; + this->description_.tile_description = this->description_.gemm.tile_description; + }; + +public: + mutable CudaBuffer strideA_device; + mutable CudaBuffer strideB_device; + mutable CudaBuffer strideC_device; + mutable CudaBuffer strideD_device; + + /// Returns the description of the GEMM operation + virtual OperationDescription const& description() const override final { return description_; } + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const* configuration) const override final { + return sizeof(Operator); + } + +protected: + library::GroupedGemmDescription description_; + + Status initialize_strides(GemmGroupedConfiguration const& config) const { + auto const num_groups = config.problem_count; + this->strideA_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideA) * num_groups); + this->strideB_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideB) * num_groups); + this->strideC_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideC) * num_groups); + this->strideD_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideD) * num_groups); + + std::vector strideA_host(num_groups); + std::vector strideB_host(num_groups); + std::vector strideC_host(num_groups); + std::vector strideD_host(num_groups); + for (int group_idx = 0; group_idx < num_groups; group_idx++) { + strideA_host[group_idx] = + cute::make_int_tuple_from( + config.lda[group_idx]); + strideB_host[group_idx] = + cute::make_int_tuple_from( + config.ldb[group_idx]); + strideC_host[group_idx] = + cute::make_int_tuple_from( + config.ldc[group_idx]); + strideD_host[group_idx] = + cute::make_int_tuple_from( + config.ldc[group_idx]); + } + CUDA_CHECK(cudaMemcpy( + this->strideA_device.data(), + strideA_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideA) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + 
this->strideB_device.data(), + strideB_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideB) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->strideC_device.data(), + strideC_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideC) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->strideD_device.data(), + strideD_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideD) * num_groups, + cudaMemcpyHostToDevice)); + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + Status update_arguments_base( + OperatorArguments& operator_args, + GemmGroupedArguments const& arguments) const { + operator_args.mode = cutlass::gemm::GemmUniversalMode::kGrouped; + operator_args.problem_shape = { + arguments.problem_count, + arguments.problem_sizes_3x, + arguments.pointer_mode == ScalarPointerMode::kHost ? arguments.problem_sizes_3x_host + : nullptr}; + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments.ptr_A); + operator_args.mainloop.ptr_B = static_cast(arguments.ptr_B); + + using RuntimeDataTypeA = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeA; + using RuntimeDataTypeB = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeB; + + static_assert(cute::is_same_v, + "RuntimeDataTypeA/B should be identical, either MXF8F6F4Format or MXF4Format"); + using RuntimeDatatypeArg = RuntimeDataTypeA; + + auto mapping = [](RuntimeDatatype type) { + if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE5M2) { + return cute::UMMA::MXF8F6F4Format::E5M2; + } + else if (type == RuntimeDatatype::kE4M3) { + return cute::UMMA::MXF8F6F4Format::E4M3; + } + else if (type == 
RuntimeDatatype::kE3M2) { + return cute::UMMA::MXF8F6F4Format::E3M2; + } + else if (type == RuntimeDatatype::kE2M3) { + return cute::UMMA::MXF8F6F4Format::E2M3; + } + else if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF8F6F4Format::E2M1; + } + else { + #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && CUTLASS_DEBUG_TRACE_LEVEL >= 1 + std::cerr << "Invalid input datatype specified. Running with e4m3." << std::endl; + #endif + return cute::UMMA::MXF8F6F4Format::E4M3; + } + } + else if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF4Format::E2M1; + } + else { + #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && CUTLASS_DEBUG_TRACE_LEVEL >= 1 + std::cerr << "Invalid input datatype specified. Running with e2m1." << std::endl; + #endif + return cute::UMMA::MXF4Format::E2M1; + } + } + // BlockScaled kernels receive either MXF4Format or MXF8F6F4Format runtime datatype + CUTE_GCC_UNREACHABLE; + }; + operator_args.mainloop.runtime_data_type_a = mapping(arguments.runtime_input_datatype_a); + operator_args.mainloop.runtime_data_type_b = mapping(arguments.runtime_input_datatype_b); + } + else { + operator_args.mainloop.ptr_A = static_cast(arguments.ptr_A); + operator_args.mainloop.ptr_B = static_cast(arguments.ptr_B); + } + operator_args.epilogue.ptr_C = static_cast(arguments.ptr_C); + operator_args.epilogue.ptr_D = static_cast(arguments.ptr_D); + + operator_args.mainloop.dA = + static_cast(this->strideA_device.data()); + operator_args.mainloop.dB = + static_cast(this->strideB_device.data()); + operator_args.epilogue.dC = + static_cast(this->strideC_device.data()); + operator_args.epilogue.dD = + static_cast(this->strideD_device.data()); + + /* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments.sm_count; + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + operator_args.hw_info.max_active_clusters = 
arguments.max_active_clusters; + } + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments.swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments.raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = + dim3(arguments.cluster_shape.m(), arguments.cluster_shape.n(), arguments.cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments.cluster_shape_fallback.m(), + arguments.cluster_shape_fallback.n(), + arguments.cluster_shape_fallback.k()); + } + return Status::kSuccess; + } + + template + static Status update_fusion_args(FusionArgs& fusion_args, GemmGroupedArguments const& arguments) { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } +}; + +/// **** CAUTION **** +/// Unlike other operations, initialize() must be called when +/// certain arguments change. See initialize() for details. 
+template +class GroupedGemmUniversal3xOperation : public GroupedGemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + +public: + GroupedGemmUniversal3xOperation(char const* name = "unknown_gemm") + : GroupedGemmOperation3xBase(name) {} + + ~GroupedGemmUniversal3xOperation() override = default; + +private: + int max_active_clusters{}; + +protected: + template struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, GemmGroupedArguments const& arguments) { + return GroupedGemmOperation3xBase::update_fusion_args(fusion_args, arguments); + } + }; + + /// Constructs the arguments structure given the configuration and arguments + Status + update_arguments_(OperatorArguments& operator_args, GemmGroupedArguments const* arguments) const { + + Status status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, + *arguments); + if (status != Status::kSuccess) { + return status; + } + + status = this->update_arguments_base(operator_args, *arguments); + return status; + } + +public: + /// Returns success if the operation can proceed + Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr) + const override { + GemmGroupedArguments const* arguments = static_cast(arguments_ptr); + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + status = Operator::can_implement(args); + return status; + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr) + const override { + + OperatorArguments 
args; + auto status = update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + /// **** CAUTION **** + /// Must be called when lda, ldb, ldc, or ldd change. + /// The CUTLASS library stores the operations in a type- + /// erased manifest. Therefore, only this class knows + /// the type of strideA, strideB, strideC, and strideD. + /// Since grouped GEMM needs to allocate storage for + /// the strides on device, the concrete type of the stride + /// must be known in order to copy in the correct memory + /// layout on device. + Status initialize( + void const* configuration_ptr, + void* host_workspace, + void* device_workspace, + cudaStream_t stream = nullptr) const override { + + Operator* op = new (host_workspace) Operator; + + auto const& config = *static_cast(configuration_ptr); + return this->initialize_strides(config); + } + + /// **** CAUTION **** + /// initialize() must be called if lda, ldb, ldc, or ldd change. + Status run( + void const* arguments_ptr, + void* host_workspace, + void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments operator_args; + auto const& args = *static_cast(arguments_ptr); + + Status status = update_arguments_(operator_args, &args); + if (status != Status::kSuccess) { + return status; + } + + Operator* op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(operator_args, device_workspace, stream, nullptr, args.use_pdl); + return status; + } + + // Set arguments that should only be set once before verifying or profiling the kernel. + // This should encompass any expensive operations that don't vary from run to run + // (e.g., max_active_clusters). 
+ Status initialize_with_arguments(void* arguments_ptr) const override { + if constexpr (Operator::ArchTag::kMinComputeCapability < 90) { + return Status::kSuccess; + } + + GemmGroupedArguments* args = static_cast(arguments_ptr); + + dim3 cluster_dims; + if constexpr (cute::is_static_v) { + cluster_dims = dim3( + cute::size<0>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<1>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<2>(typename Operator::GemmKernel::ClusterShape{}) + ); + } + else { + cluster_dims = dim3( + args->cluster_shape.m(), + args->cluster_shape.n(), + args->cluster_shape.k() + ); + } + + uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock; + void const* kernel_ptr = (void*)(device_kernel); + args->max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters( + cluster_dims, + threads_per_block, + kernel_ptr); + + if (args->max_active_clusters == 0) { + std::cerr << "Max Active Clusters could not be queried. 
" + << "Falling back to heuristics mode (static cluster shape) or preferred cluster mode.\n"; + } + + return Status::kSuccess; + } +}; + +template +class GroupedBlockScaledGemmUniversal3xOperation : public GroupedGemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + using ElementSFA = typename Operator::CollectiveMainloop::ElementSF; + using ElementSFB = typename Operator::CollectiveMainloop::ElementSF; + + using TiledMma = typename Operator::CollectiveMainloop::TiledMma; + constexpr static int SFVecSize = TiledMma::SFVecSize; + + + static constexpr bool epilogue_scalefactor_generation = not cute::is_same_v; + static constexpr int32_t SFD_VectorSize = epilogue_scalefactor_generation ? 
ThreadEpilogueOp::SFVecSize : SFVecSize; + using ElementSFD = cute::conditional_t; + using LayoutSFD = cute::conditional_t; + + GroupedBlockScaledGemmUniversal3xOperation(char const* name = "unknown_gemm") + : GroupedGemmOperation3xBase(name) { + + BlockScaleDescription block_scaled_desc{}; + block_scaled_desc.kind = OperationKind::kBlockScaledGemm; + block_scaled_desc.SFA.element = NumericTypeMap::kId; + block_scaled_desc.SFA.layout = LayoutTypeID::kRowMajor; + block_scaled_desc.SFA.alignment = 128; + block_scaled_desc.SFA.log_extent_range = 32; + block_scaled_desc.SFA.log_stride_range = 32; + + block_scaled_desc.SFB.element = NumericTypeMap::kId; + block_scaled_desc.SFB.layout = LayoutTypeID::kRowMajor; + block_scaled_desc.SFB.alignment = 128; + block_scaled_desc.SFB.log_extent_range = 32; + block_scaled_desc.SFB.log_stride_range = 32; + + block_scaled_desc.SFMVecSize = 1; + block_scaled_desc.SFNVecSize = 1; + block_scaled_desc.SFKVecSize = SFVecSize; + + block_scaled_desc.SFD = make_TensorDescription(128); + block_scaled_desc.EpilogueSFVecSize = SFD_VectorSize; + + this->description_.block_scales = block_scaled_desc; + } + + ~GroupedBlockScaledGemmUniversal3xOperation() override = default; + + mutable CudaBuffer layout_SFA_device; + mutable CudaBuffer layout_SFB_device; + +protected: + template struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status + update_(FusionArgs& fusion_args, GroupedGemmBlockScaledArguments const& arguments) { + + if constexpr (epilogue_scalefactor_generation) { + fusion_args.block_scale_factor_ptr = static_cast(arguments.SFD); + fusion_args.norm_constant_ptr = static_cast(arguments.norm_constant); + } + + return 
GroupedGemmOperation3xBase::update_fusion_args(fusion_args, arguments); + } + }; + +public: + /// Returns success if the operation can proceed + Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr) + const override { + GroupedGemmBlockScaledArguments const* arguments = + static_cast(arguments_ptr); + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + status = Operator::can_implement(args); + return status; + } + + Status update_arguments_( + OperatorArguments& operator_args, + GroupedGemmBlockScaledArguments const* arguments) const { + Status status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, + *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.mainloop.ptr_SFA = + static_cast(arguments->SFA); + operator_args.mainloop.ptr_SFB = + static_cast(arguments->SFB); + + operator_args.mainloop.layout_SFA = + static_cast(this->layout_SFA_device.data()); + operator_args.mainloop.layout_SFB = + static_cast(this->layout_SFB_device.data()); + + return this->update_arguments_base(operator_args, *arguments); + } + + uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr) + const override { + + OperatorArguments args; + auto status = + update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + /// **** CAUTION **** + /// Must be called when lda, ldb, ldc, or ldd change. + /// The CUTLASS library stores the operations in a type- + /// erased manifest. Therefore, only this class knows + /// the type of strideA, strideB, strideC, and strideD. 
+ /// Since grouped GEMM needs to allocate storage for + /// the strides on device, the concrete type of the stride + /// must be known in order to copy in the correct memory + /// layout on device. + Status initialize( + void const* configuration_ptr, + void* host_workspace, + void* device_workspace, + cudaStream_t stream = nullptr) const override { + + auto const& config = *static_cast(configuration_ptr); + auto status = this->initialize_strides(config); + if (status != Status::kSuccess) { + return status; + } + + auto num_groups = config.problem_count; + this->layout_SFA_device = + CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups); + this->layout_SFB_device = + CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups); + auto layout_SFA_host = std::vector(num_groups); + auto layout_SFB_host = std::vector(num_groups); + + for (int group_idx = 0; group_idx < num_groups; group_idx++) { + auto const& shape = config.problem_sizes_3x_host[group_idx]; + auto M = get<0>(shape); + auto N = get<1>(shape); + auto K = get<2>(shape); + + auto layout_SFA = CollectiveMainloop::Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = CollectiveMainloop::Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + layout_SFA_host[group_idx] = layout_SFA; + layout_SFB_host[group_idx] = layout_SFB; + } + + CUDA_CHECK(cudaMemcpy( + this->layout_SFA_device.data(), + layout_SFA_host.data(), + sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->layout_SFB_device.data(), + layout_SFB_host.data(), + sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups, + cudaMemcpyHostToDevice)); + + Operator* op = new (host_workspace) Operator; + return status; + } + + /// **** CAUTION **** + /// initialize() must be called if lda, ldb, ldc, or ldd change. 
+ Status run( + void const* arguments_ptr, + void* host_workspace, + void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments operator_args; + auto const& args = *static_cast(arguments_ptr); + + Status status = update_arguments_(operator_args, &args); + if (status != Status::kSuccess) { + return status; + } + + Operator* op = static_cast(host_workspace); + status = op->run(operator_args, device_workspace, stream, nullptr); + return status; + } +}; + +template +class GroupedBlockwiseGemmUniversal3xOperation : public GroupedGemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + using ElementSFA = typename Operator::ElementAccumulator; + using ElementSFB = typename Operator::ElementAccumulator; + + using TiledMma = typename Operator::CollectiveMainloop::TiledMma; + + GroupedBlockwiseGemmUniversal3xOperation(char const* name = "unknown_gemm") + : GroupedGemmOperation3xBase(name) { + + BlockScaleDescription blockwise_desc{}; + blockwise_desc.kind = OperationKind::kBlockwiseGemm; + blockwise_desc.SFA.element = NumericTypeMap::kId; + blockwise_desc.SFA.layout = size<0,1>(typename CollectiveMainloop::InternalLayoutSFA{}.stride()) == 1 ? 
+ LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; + blockwise_desc.SFA.alignment = CollectiveMainloop::AlignmentSFA; + blockwise_desc.SFA.log_extent_range = 32; + blockwise_desc.SFA.log_stride_range = 32; + + blockwise_desc.SFB.element = NumericTypeMap::kId; + blockwise_desc.SFB.layout = size<0,1>(typename CollectiveMainloop::InternalLayoutSFB{}.stride()) == 1 ? + LayoutTypeID::kRowMajor : LayoutTypeID::kColumnMajor; + blockwise_desc.SFB.alignment = CollectiveMainloop::AlignmentSFA; + blockwise_desc.SFB.log_extent_range = 32; + blockwise_desc.SFB.log_stride_range = 32; + + blockwise_desc.SFMVecSize = Operator::CollectiveMainloop::ScaleGranularityM; + blockwise_desc.SFNVecSize = Operator::CollectiveMainloop::ScaleGranularityN; + blockwise_desc.SFKVecSize = Operator::CollectiveMainloop::ScaleGranularityK; + + blockwise_desc.EpilogueSFVecSize = 0; + + this->description_.block_scales = blockwise_desc; + } + + ~GroupedBlockwiseGemmUniversal3xOperation() override = default; + + mutable CudaBuffer layout_SFA_device; + mutable CudaBuffer layout_SFB_device; + +protected: + template struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status + update_(FusionArgs& fusion_args, GroupedGemmBlockwiseArguments const& arguments) { + return GroupedGemmOperation3xBase::update_fusion_args(fusion_args, arguments); + } + }; + +public: + /// Returns success if the operation can proceed + Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr) + const override { + GroupedGemmBlockwiseArguments const* arguments = + static_cast(arguments_ptr); + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + 
return status; + } + + status = Operator::can_implement(args); + return status; + } + + Status update_arguments_( + OperatorArguments& operator_args, + GroupedGemmBlockwiseArguments const* arguments) const { + Status status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, + *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.mainloop.ptr_SFA = + static_cast(arguments->SFA); + operator_args.mainloop.ptr_SFB = + static_cast(arguments->SFB); + + operator_args.mainloop.layout_SFA = + static_cast(this->layout_SFA_device.data()); + operator_args.mainloop.layout_SFB = + static_cast(this->layout_SFB_device.data()); + + return this->update_arguments_base(operator_args, *arguments); + } + + uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr) + const override { + + OperatorArguments args; + auto status = + update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + /// **** CAUTION **** + /// Must be called when lda, ldb, ldc, or ldd change. + /// The CUTLASS library stores the operations in a type- + /// erased manifest. Therefore, only this class knows + /// the type of strideA, strideB, strideC, and strideD. + /// Since grouped GEMM needs to allocate storage for + /// the strides on device, the concrete type of the stride + /// must be known in order to copy in the correct memory + /// layout on device. 
+ Status initialize( + void const* configuration_ptr, + void* host_workspace, + void* device_workspace, + cudaStream_t stream = nullptr) const override { + + auto const& config = *static_cast(configuration_ptr); + auto status = this->initialize_strides(config); + if (status != Status::kSuccess) { + return status; + } + + auto num_groups = config.problem_count; + this->layout_SFA_device = + CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups); + this->layout_SFB_device = + CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups); + auto layout_SFA_host = std::vector(num_groups); + auto layout_SFB_host = std::vector(num_groups); + + for (int group_idx = 0; group_idx < num_groups; group_idx++) { + auto const& shape = config.problem_sizes_3x_host[group_idx]; + auto M = get<0>(shape); + auto N = get<1>(shape); + auto K = get<2>(shape); + + auto layout_SFA = CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + layout_SFA_host[group_idx] = layout_SFA; + layout_SFB_host[group_idx] = layout_SFB; + } + + CUDA_CHECK(cudaMemcpy( + this->layout_SFA_device.data(), + layout_SFA_host.data(), + sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->layout_SFB_device.data(), + layout_SFB_host.data(), + sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups, + cudaMemcpyHostToDevice)); + + Operator* op = new (host_workspace) Operator; + return status; + } + + /// **** CAUTION **** + /// initialize() must be called if lda, ldb, ldc, or ldd change. 
+ Status run( + void const* arguments_ptr, + void* host_workspace, + void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments operator_args; + auto const& args = *static_cast(arguments_ptr); + + Status status = update_arguments_(operator_args, &args); + if (status != Status::kSuccess) { + return status; + } + + Operator* op = static_cast(host_workspace); + status = op->run(operator_args, device_workspace, stream, nullptr); + return status; + } +}; + + +} // namespace cutlass::library diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/library_internal.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/library_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..e8bd77397f3b85cce2da2a7a8e447ab6ccb48aea --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/library_internal.h @@ -0,0 +1,427 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. 
+ + Generally, + + description - compile-time constant parameters used to instantiate an operation + + configuration - runtime parameters with computationally expensive initialization + + arguments - runtime parameters that may be passed to an initialized operation with low + computational overhead +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/arch_mappings.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct NumericTypeMap; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kVoid; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kB1; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS2; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS4; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS8; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS16; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS32; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS64; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU2; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU4; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU8; +}; + +template <> struct 
NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE4M3; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE5M2; +}; + + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE2M3; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE3M2; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE2M1; +}; +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFUE8M0; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFUE4M3; +}; + + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU16; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU32; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU64; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF16; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF32; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF64; +}; + +template <> struct NumericTypeMap > { + static NumericTypeID const kId = NumericTypeID::kCF16; +}; + +template <> struct NumericTypeMap > { + static NumericTypeID const kId = NumericTypeID::kCF32; +}; + +template <> struct NumericTypeMap > { + static NumericTypeID const kId = NumericTypeID::kCF64; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kBF16; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kTF32; +}; + + + + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF8; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF6; +}; + 
+template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF4; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kInvalid; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAdd; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddFastBF16; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddFastF16; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddSaturate; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddMixedInputUpcast; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddComplex; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddGaussianComplex; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kXorPopc; +}; + + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddFastF32; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddComplexFastF32; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct LayoutMap; + +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajor; +}; + +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kRowMajor; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK2; +}; + +template <> struct LayoutMap> { + 
static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK2; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK4; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK4; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK16; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK16; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK64; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK64; +}; + +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kTensorNHWC; +}; + +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kTensorNDHWC; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorNC32HW32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorNC64HW64; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorC32RSK32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorC64RSK64; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct OpcodeClassMap; + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kSimt; +}; + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kTensorOp; +}; + +template <> struct OpcodeClassMap { + static OpcodeClassID 
const kId = OpcodeClassID::kSparseTensorOp; +}; + + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kBlockScaledOp; +}; + + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct ComplexTransformMap; + +template <> struct ComplexTransformMap { + static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kNone; +}; + +template <> struct ComplexTransformMap { + static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kConjugate; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct ConvModeMap; + +template <> struct ConvModeMap { + static ConvModeID const kId = ConvModeID::kCrossCorrelation; +}; + +template <> struct ConvModeMap { + static ConvModeID const kId = ConvModeID::kConvolution; +}; + + +template struct ConvKindMap; + +template <> struct ConvKindMap { + static ConvKind const kId = ConvKind::kFprop; +}; + +template <> struct ConvKindMap { + static ConvKind const kId = ConvKind::kDgrad; +}; + +template <> struct ConvKindMap { + static ConvKind const kId = ConvKind::kWgrad; +}; + + +template struct IteratorAlgorithmMap; + +template <> struct IteratorAlgorithmMap { + static IteratorAlgorithmID const kId = IteratorAlgorithmID::kAnalytic; +}; + +template <> struct IteratorAlgorithmMap { + static IteratorAlgorithmID const kId = IteratorAlgorithmID::kOptimized; +}; + +template <> struct IteratorAlgorithmMap { + static IteratorAlgorithmID const kId = IteratorAlgorithmID::kFixedChannels; +}; + +template <> struct IteratorAlgorithmMap { + static IteratorAlgorithmID const kId = IteratorAlgorithmID::kFewChannels; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template 
+TensorDescription make_TensorDescription(int alignment = 1) { + TensorDescription desc; + + desc.element = NumericTypeMap::kId; + desc.layout = LayoutMap::kId; + desc.alignment = alignment; + desc.log_extent_range = int(sizeof(typename Layout::TensorCoord::Index) - 1) * 8; + desc.log_stride_range = int(sizeof(typename Layout::Stride::Index) - 1) * 8; + + return desc; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/rank_2k_operation.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/rank_2k_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..76d8d0dfdb1aa6ed0324b9d6299b06ebf3f436d9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/rank_2k_operation.h @@ -0,0 +1,377 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all Rank 2K operation kinds (Syr2k, Her2k) + in CUTLASS Library. 
+ + +*/ + +#pragma once +#include +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/rank_2k.h" +#include "cutlass/gemm/kernel/default_rank_2k_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/core_io.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Rank2KOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static BlasMode const kBlasMode = Operator::kBlasMode; + static int const kUpdateRank = Operator::kUpdateRank; + static FillMode const kFillModeC = Operator::kFillModeC; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + RankKDescription description_; + +public: + + /// Constructor + Rank2KOperationBase(char const *name = "unknown_rank_k") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.rank_k_kind = RankKKind::kUniversal; + description_.fill_mode = kFillModeC; + description_.blas_mode = kBlasMode; + description_.num_ranks = kUpdateRank; + + description_.kind = OperationKind::kRank2K; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + 
description_.tile_description.warp_count = make_Coord( + Operator::Rank2Kkernel::WarpCount::kM, + Operator::Rank2Kkernel::WarpCount::kN, + Operator::Rank2Kkernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; + } + + /// Returns the description of the SYRK operation + virtual OperationDescription const & description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Rank2KOperation : public Rank2KOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename 
Operator::EpilogueOutputOp::ElementCompute; + + static BlasMode const kBlasMode = Operator::kBlasMode; + static int const kUpdateRank = Operator::kUpdateRank; + static FillMode const kFillModeC = Operator::kFillModeC; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + Rank2KOperation(char const *name = "unknown_rank_2k"): + Rank2KOperationBase(name) { + + this->description_.rank_k_kind = RankKKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + RankKConfiguration const *configuration) { + + //operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = int(configuration->lda); + operator_args.ldb = int(configuration->ldb); + operator_args.ldc = int(configuration->ldc); + operator_args.ldd = int(configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + RankKArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_B = arguments->B; + operator_args.ptr_C = arguments->C; + operator_args.ptr_D = arguments->D; + + operator_args.batch_stride_A = arguments->batch_stride_A; 
+ operator_args.batch_stride_B = arguments->batch_stride_B; + operator_args.batch_stride_C = arguments->batch_stride_C; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + RankKConfiguration const *configuration = + static_cast(configuration_ptr); + + RankKArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + //std::cout << "initialize() library::Rank2KOperation" << std::endl; + //print_operator_args(args); + status = 
op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + //std::cout << "run() library::Rank2KOperation" << std::endl; + //print_operator_args(args); + status = op->run(stream); + + return status; + } + + /// Call print_operator_args from the Conv2dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Rank2KOperation::OperatorArguments" << std::endl + << " problem_size:" << std::endl + << operator_args.problem_size << std::endl + << " epilogue (alpha, beta): " + << operator_args.epilogue.alpha << ", " + << operator_args.epilogue.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ptr_A << ", {" + << operator_args.lda << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ptr_B << ", {" + << operator_args.ldb << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ptr_C << ", {" + << operator_args.ldc << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ptr_D << ", {" + << operator_args.ldd << "}" << std::endl; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/rank_k_operation.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/rank_k_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..021f7f03fcc4449bdc2ef2c97e29fe0fead09a64 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/rank_k_operation.h @@ -0,0 +1,348 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all Rank K operation kinds (Syrk, Herk) + in CUTLASS Library. + + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/rank_k.h" +#include "cutlass/gemm/kernel/default_rank_k_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class RankKOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementA; + using LayoutB = typename Operator::LayoutA; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static BlasMode const kBlasMode = Operator::kBlasMode; + static int const kUpdateRank = Operator::kUpdateRank; + static FillMode const kFillModeC = Operator::kFillModeC; + + using OperatorArguments = typename Operator::Arguments; + +protected: 
+ + /// + RankKDescription description_; + +public: + + /// Constructor + RankKOperationBase(char const *name = "unknown_rank_k") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.rank_k_kind = RankKKind::kUniversal; + description_.fill_mode = kFillModeC; + description_.blas_mode = kBlasMode; + description_.num_ranks = kUpdateRank; + + description_.kind = OperationKind::kRankK; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::RankKkernel::WarpCount::kM, + Operator::RankKkernel::WarpCount::kN, + Operator::RankKkernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentA); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; + } + + /// Returns the description of the SYRK operation + virtual OperationDescription const & description() 
const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class RankKOperation : public RankKOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementA; + using LayoutB = typename Operator::LayoutA; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + static BlasMode const kBlasMode = Operator::kBlasMode; + static int const kUpdateRank = Operator::kUpdateRank; + static FillMode const kFillModeC = Operator::kFillModeC; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + RankKOperation(char const *name = "unknown_rank_k"): + RankKOperationBase(name) { + + this->description_.rank_k_kind = RankKKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + RankKConfiguration const *configuration) { + + //operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = int(configuration->lda); + operator_args.ldb = int(configuration->lda); + operator_args.ldc = int(configuration->ldc); + operator_args.ldd = int(configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + RankKArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params 
params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_C = arguments->C; + operator_args.ptr_D = arguments->D; + + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_C = arguments->batch_stride_C; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + RankKConfiguration const *configuration = + static_cast(configuration_ptr); + + RankKArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + 
} + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/reduction/reduction_operation.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/reduction/reduction_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..6e948540e3f29dceace42b5e8ef3f91118c01b37 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/reduction/reduction_operation.h @@ -0,0 +1,294 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for reduction operation in CUTLASS Library. 
+*/ + +#pragma once +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" +#include "cutlass/reduction/thread/reduction_operators.h" +#include "cutlass/reduction/device/reduce_split_k.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/core_io.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class ReductionOperation : public Operation { +public: + using Operator = Operator_; + + using ElementWorkspace = typename Operator::ElementWorkspace; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementOutput = typename Operator::ElementOutput; + + using ElementCompute = typename Operator::OutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + ReductionDescription description_; + +public: + + /// Constructor + ReductionOperation(char const *name = "unknown_reduction") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kReduction; + + description_.tile_description.threadblock_shape = make_Coord(Operator::Shape::kRow, Operator::Shape::kColumn, 1); + + description_.tile_description.math_instruction.instruction_shape = make_Coord(1, 1, 1); + description_.tile_description.math_instruction.element_accumulator = NumericTypeMap::kId; + description_.tile_description.math_instruction.opcode_class = OpcodeClassID::kSimt; + description_.tile_description.math_instruction.math_operation = MathOperationID::kAdd; + + description_.tile_description.minimum_compute_capability = 50; + description_.tile_description.maximum_compute_capability = 1024; + + description_.element_workspace = 
NumericTypeMap::kId; + description_.element_output = NumericTypeMap::kId; + description_.element_epilogue = NumericTypeMap::kId; + + } + + /// Returns the description of the Reduction operation + virtual OperationDescription const & description() const { + return description_; + } + + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + ReductionConfiguration const *configuration) { + + operator_args.problem_size = configuration->problem_size; + operator_args.partitions = configuration->partitions; + operator_args.partition_stride = configuration->partition_stride; + + operator_args.workspace = {nullptr, int(configuration->ldw)}; + operator_args.source = {nullptr, int(configuration->lds)}; + operator_args.destination = {nullptr, int(configuration->ldd)}; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + ReductionArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::OutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.output = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::OutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.output = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.workspace.reset(static_cast(const_cast(arguments->workspace))); + operator_args.source.reset(static_cast(const_cast(arguments->source))); + operator_args.destination.reset(static_cast(const_cast(arguments->destination))); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can 
proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + ReductionConfiguration const *configuration = + static_cast(configuration_ptr); + + ReductionArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + //std::cout << "initialize library::Reduction" << std::endl; + //print_operator_args(args); + return op->initialize(args, device_workspace, stream); + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return 
status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + //std::cout << "run library::Reduction" << std::endl; + //print_operator_args(args); + return op->run(stream); + } + + /// Call print_operator_args from the Reduction::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Reduction::OperatorArguments" << std::endl + << " problem_size: " + << operator_args.problem_size << std::endl + << " partitions: " + << operator_args.partitions << std::endl + << " partition_stride: " + << operator_args.partition_stride << std::endl + << " epilogue (alpha, beta): " + << operator_args.output.alpha << ", " + << operator_args.output.beta << std::endl + << " workspace (ptr, stride): " + << operator_args.workspace.data() << ", " + << operator_args.workspace.stride(0) << std::endl + << " source (ptr, stride): " + << operator_args.source.data() << ", " + << operator_args.source.stride(0) << std::endl + << " destination (ptr, stride): " + << operator_args.destination.data() << ", " + << operator_args.destination.stride(0) << std::endl; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/block_scaled_gemm_reference_operation.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/block_scaled_gemm_reference_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..769da1c8515877536fd9b9fd72c836fd43ebd5d8 --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/block_scaled_gemm_reference_operation.h @@ -0,0 +1,453 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Defines reference operations for block-scaled GEMM operation kinds in CUTLASS Library +*/ + + + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +#include "cutlass/util/packed_stride.hpp" +#include "library_internal.h" + +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +namespace detail { +template +auto make_iterator(T* ptr) { + return cute::recast_ptr(ptr); +} +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + Provider Provider_, + typename ElementA_, + typename LayoutA_, + typename ElementSFA_, + typename ElementB_, + typename LayoutB_, + typename ElementSFB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ElementSFD_ = void, + typename LayoutSFD_ = LayoutC_, + int SFVecSize_ = 32, + int EpilogueSFVecSize_ = 0, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +class BlockScaledGemmReferenceOperation : public Operation { +public: + static Provider const kProvider = Provider_; + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementSFA = ElementSFA_; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementSFB = ElementSFB_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementD = ElementD_; + using ElementSFD = ElementSFD_; + using LayoutSFD = LayoutSFD_; + using ElementCompute = ElementCompute_; + using 
ElementAccumulator = ElementAccumulator_; + using ConvertOp = ConvertOp_; + using InnerProductOp = InnerProductOp_; + constexpr static int SFVecSize = SFVecSize_; + constexpr static int EpilogueSFVecSize = EpilogueSFVecSize_; + +protected: + + /// Storage for the name string + std::string name_; + + /// + BlockScaledGemmDescription description_; + +public: + + /// Constructor + BlockScaledGemmReferenceOperation() { + + // Basic information + description_.provider = kProvider; + description_.kind = OperationKind::kBlockScaledGemm; + description_.gemm_kind = GemmKind::kUniversal; + + // Tensor description + description_.A = make_TensorDescription(); + description_.SFA = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.SFB = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.D = make_TensorDescription(); + description_.SFD = make_TensorDescription(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + // Compute capability for gemm reference + description_.tile_description.minimum_compute_capability = + (kProvider == Provider::kReferenceDevice ? 
50 : 0); + + description_.tile_description.maximum_compute_capability = 1024; + + description_.SFVecSize = SFVecSize; + description_.EpilogueSFVecSize = EpilogueSFVecSize; + + // Procedural name + std::stringstream ss; + + ss << "gemm" + << "_reference_" << to_string(description_.provider) + << "_" << to_string(description_.A.element) << to_string(description_.A.layout) + << "_" << to_string(description_.SFA.element) << to_string(description_.SFA.layout) + << "_" << to_string(description_.B.element) << to_string(description_.B.layout) + << "_" << to_string(description_.SFB.element) << to_string(description_.SFB.layout) + << "_" << to_string(description_.C.element) << to_string(description_.C.layout) + << "_" << to_string(description_.SFD.element) << to_string(description_.SFD.layout) + << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); + + name_ = ss.str(); + + description_.name = name_.c_str(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + virtual Status can_implement( + void const *configuration, + void const *arguments) const { + + return Status::kSuccess; + } + + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(GemmUniversalConfiguration); + } + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const { + + return 0; + } + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + return Status::kSuccess; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = 
nullptr, + cudaStream_t stream = nullptr) const { + using namespace cute; + + BlockScaledGemmArguments const &args = *static_cast(arguments); + + // Construct cute::Tensor A/B/C + + int M = args.problem_size.m(); + int N = args.problem_size.n(); + int K = args.problem_size.k(); + int L = args.batch_count; + + auto problem_shape_MNKL = cute::make_shape(M, N, K, L); + + auto alpha = *(static_cast(args.alpha)); + auto beta = *(static_cast(args.beta)); + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using StrideC = cutlass::gemm::TagToStrideC_t; + using StrideD = cutlass::gemm::TagToStrideC_t; + + auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + using Sm1xxBlockScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + auto A = cute::make_tensor(detail::make_iterator(static_cast(args.A)), + cute::make_layout(cute::make_shape(M, K, L), stride_a)); + auto SfA = make_tensor(static_cast(args.SFA), Sm1xxBlockScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL)); + + auto B = cute::make_tensor(detail::make_iterator(static_cast(args.B)), + cute::make_layout(cute::make_shape(N, K, L), stride_b)); + auto SfB = make_tensor(static_cast(args.SFB), Sm1xxBlockScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL)); + + auto C = [&]() { + if constexpr (not is_same_v) { + return cute::make_tensor(detail::make_iterator(static_cast(args.C)), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + } + else { + return cute::make_tensor(detail::make_iterator(static_cast(nullptr)), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + } + }(); + + auto D = 
cute::make_tensor(detail::make_iterator(static_cast(args.D)), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + + cutlass::reference::host::GettBlockScalingMainloopParams + mainloop_params{A, SfA, B, SfB}; + + if constexpr (not is_same_v) { + + using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig< + EpilogueSFVecSize + >; + + auto SfD = cute::make_tensor(detail::make_iterator(static_cast(args.SFD)), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL)); + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementCompute, ElementAccumulator, ElementCompute, + decltype(C), decltype(D), decltype(SfD), Int, cutlass::reference::host::SfStrategy::SfDGen> + epilogue_params{alpha, beta, C, D, SfD, *(static_cast(args.norm_constant))}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + } + else { + // W/O SF generation + auto SfD = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L))); // not used. 
+ cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementCompute, ElementAccumulator, ElementCompute, + decltype(C), decltype(D), decltype(SfD)> + epilogue_params{alpha, beta, C, D, SfD}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + } + + return Status::kSuccess; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename ElementSFA_, + typename ElementB_, + typename ElementSFB_, + typename ElementC_, + typename ElementCompute_, + typename ElementSFD_ = void, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + int SFVecSize = 32, + int EpilogueSFVecSize = SFVecSize, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_block_scaled_gemm_tn(Manifest &manifest) { +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); +#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename ElementSFA_, + typename ElementB_, + typename ElementSFB_, + typename ElementC_, + typename ElementCompute_, + typename ElementSFD_ = void, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + int SFVecSize = 32, + int EpilogueSFVecSize = SFVecSize, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_block_scaled_gemm(Manifest &manifest) { + /// + /// A is 
Row , B is Col + /// + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); + /// + /// A is Col , B is Row + /// + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git 
a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/blockwise_gemm_reference_operation.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/blockwise_gemm_reference_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..fd988f899f563acfc6f8003bdb49523bca51d6d9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/blockwise_gemm_reference_operation.h @@ -0,0 +1,807 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines reference operations for blockwise/groupwise GEMM operation kinds in CUTLASS Library +*/ + + + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +#include "cutlass/util/packed_stride.hpp" +#include "library_internal.h" + +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/detail/blockwise_scale_layout.hpp" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + Provider Provider_, + typename ElementA_, + typename LayoutA_, + typename LayoutSFA_, + typename ElementSFA_, + typename ElementB_, + typename LayoutB_, + typename LayoutSFB_, + typename ElementSFB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +class BlockwiseGemmReferenceOperation : public Operation { +public: + static Provider const kProvider = Provider_; + + using ElementA = ElementA_; + 
using LayoutA = LayoutA_; + using ElementSFA = ElementSFA_; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementSFB = ElementSFB_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementD = ElementD_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ConvertOp = ConvertOp_; + using InnerProductOp = InnerProductOp_; + +protected: + + /// Storage for the name string + std::string name_; + + /// + BlockwiseGemmDescription description_; + +public: + + /// Constructor + BlockwiseGemmReferenceOperation(int SFMVecSize_, int SFNVecSize_, int SFKVecSize_) + : SFMVecSize(SFMVecSize_), SFNVecSize(SFNVecSize_), SFKVecSize(SFKVecSize_) { + + // Basic information + description_.provider = kProvider; + description_.kind = OperationKind::kBlockwiseGemm; + description_.gemm_kind = GemmKind::kUniversal; + + // Tensor description + description_.A = make_TensorDescription(); + description_.SFA = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.SFB = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.D = make_TensorDescription(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + // Compute capability for gemm reference + description_.tile_description.minimum_compute_capability = + (kProvider == Provider::kReferenceDevice ? 
50 : 0); + + description_.tile_description.maximum_compute_capability = 1024; + + description_.SFMVecSize = SFMVecSize; + description_.SFNVecSize = SFNVecSize; + description_.SFKVecSize = SFKVecSize; + + // Procedural name + std::stringstream ss; + + ss << "gemm" + << "_reference_" << to_string(description_.provider) + << "_" << to_string(description_.A.element) << to_string(description_.A.layout) + << "_" << to_string(description_.SFA.element) << SFMVecSize << "x" << SFKVecSize << to_string(description_.SFA.layout) + << "_" << to_string(description_.B.element) << to_string(description_.B.layout) + << "_" << to_string(description_.SFB.element) << SFNVecSize << "x" << SFKVecSize << to_string(description_.SFB.layout) + << "_" << to_string(description_.C.element) << to_string(description_.C.layout) + << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); + + name_ = ss.str(); + + description_.name = name_.c_str(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + virtual Status can_implement( + void const *configuration, + void const *arguments) const { + + return Status::kSuccess; + } + + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(GemmUniversalConfiguration); + } + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const { + + return 0; + } + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + return Status::kSuccess; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void 
*device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + using namespace cute; + + BlockwiseGemmArguments const &args = *static_cast(arguments); + + // Construct cute::Tensor A/B/C + + int M = args.problem_size.m(); + int N = args.problem_size.n(); + int K = args.problem_size.k(); + int L = args.batch_count; + + auto problem_shape_MNKL = cute::make_shape(M, N, K, L); + + auto alpha = *(static_cast(args.alpha)); + auto beta = *(static_cast(args.beta)); + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using StrideC = cutlass::gemm::TagToStrideC_t; + using StrideD = cutlass::gemm::TagToStrideC_t; + + auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + using BlockwiseConfig = cutlass::detail::RuntimeBlockwiseScaleConfig<>; + auto A = cute::make_tensor(static_cast(args.A), + cute::make_layout(cute::make_shape(M, K, L), stride_a)); + auto SfA = make_tensor(static_cast(args.SFA), BlockwiseConfig::tile_atom_to_shape_SFA(problem_shape_MNKL, cute::make_tuple(SFMVecSize, SFNVecSize, SFKVecSize))); + + auto B = cute::make_tensor(static_cast(args.B), + cute::make_layout(cute::make_shape(N, K, L), stride_b)); + auto SfB = make_tensor(static_cast(args.SFB), BlockwiseConfig::tile_atom_to_shape_SFB(problem_shape_MNKL, cute::make_tuple(SFMVecSize, SFNVecSize, SFKVecSize))); + + auto C = [&]() { + if constexpr (not is_same_v) { + return cute::make_tensor(static_cast(args.C), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + } + else { + return cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + } + }(); + + auto D = cute::make_tensor(static_cast(args.D), + 
cute::make_layout(cute::make_shape(M, N, L), stride_d)); + + cutlass::reference::host::GettBlockScalingMainloopParams + mainloop_params{A, SfA, B, SfB}; + + // W/O SF generation + cutlass::reference::host::GettEpilogueParams< + ElementCompute, ElementAccumulator, ElementAccumulator, ElementCompute, + decltype(C), decltype(D)> + epilogue_params{alpha, beta, C, D}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + return Status::kSuccess; + } + +private: + int SFMVecSize; + int SFNVecSize; + int SFKVecSize; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename ElementSFA_, + typename ElementB_, + typename ElementSFB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_blockwise_gemm(Manifest &manifest, int SFMVecSize, int SFNVecSize, int SFKVecSize) { + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + 
Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + 
ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + 
cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + 
ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + +} + +template +void initialize_blockwise_gemm_reference_operations_given_C_and_D(Manifest &manifest) { + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float 
/*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + + + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1 , 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float 
/*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + 
make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t 
/*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1 , 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/conv_reference_operation.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/conv_reference_operation.h new file 
mode 100644 index 0000000000000000000000000000000000000000..240fe18d16a27778bf75e0c02f99d251c096353f --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/conv_reference_operation.h @@ -0,0 +1,636 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all CONV operation kinds in CUTLASS Library +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +#include "library_internal.h" + +#include "cutlass/conv/convolution.h" +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + Provider kProvider, + cutlass::conv::Operator ConvolutionalOperator, + int ConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +struct ConvReferenceDispatcher; + +/// Dispatcher for Conv2d (partially specialized for kConvDim == 2) +template < + Provider kProvider, + cutlass::conv::Operator kConvolutionalOperator, + typename ElementA, + typename 
LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator, + typename ConvertOp, + typename InnerProductOp +> +struct ConvReferenceDispatcher< + kProvider, + kConvolutionalOperator, + 2, + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp> { + + static Status dispatch( + void const *configuration, + ElementA *ptr_A, + ElementB *ptr_B, + ElementC *ptr_C, + ElementC *ptr_D, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr + ) { + + Conv2dConfiguration const &config = + *static_cast(configuration); + + // TODO: make below code more general. It is fixed for NHWC now. + layout::TensorNHWC layout_a; + layout::TensorNHWC layout_b; + layout::TensorNHWC layout_c; + + layout_a.stride() = + make_Coord(int32_t(config.stride_a[0]), + int32_t(config.stride_a[1]), + int32_t(config.stride_a[2])); + + layout_b.stride() = + make_Coord(int32_t(config.stride_b[0]), + int32_t(config.stride_b[1]), + int32_t(config.stride_b[2])); + + layout_c.stride() = + make_Coord(int32_t(config.stride_c[0]), + int32_t(config.stride_c[1]), + int32_t(config.stride_c[2])); + + if (kProvider == Provider::kReferenceHost) { + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC , + LayoutC, + ElementCompute, + ElementAccumulator, + ElementC, + ConvertOp, + InnerProductOp + >( + kConvolutionalOperator, + config.problem_size, + {ptr_A, layout_a}, + {ptr_B, layout_b}, + {ptr_C, layout_c}, + {ptr_D, layout_c}, + alpha, + beta + ); + + return Status::kSuccess; + } + else if (kProvider == Provider::kReferenceDevice) { + return cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp + >( + kConvolutionalOperator, + config.problem_size, + {ptr_A, layout_a}, + 
{ptr_B, layout_b}, + {ptr_C, layout_c}, + {ptr_D, layout_c}, + alpha, + beta, + stream + ); + } + return Status::kErrorNotSupported; + } +}; + +/// Dispatcher for Conv3d (partially specialized for kConvDim == 3) +template < + Provider kProvider, + cutlass::conv::Operator kConvolutionalOperator, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator, + typename ConvertOp, + typename InnerProductOp +> +struct ConvReferenceDispatcher< + kProvider, + kConvolutionalOperator, + 3, + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp> { + + static Status dispatch( + void const *configuration, + ElementA *ptr_A, + ElementB *ptr_B, + ElementC *ptr_C, + ElementC *ptr_D, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr + ) { + + Conv3dConfiguration const &config = + *static_cast(configuration); + + ConvKind const conv_kind = ConvKindMap::kId; + + if (kProvider == Provider::kReferenceHost) { + cutlass::reference::host::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC , + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp + >( + kConvolutionalOperator, + config.problem_size, + {ptr_A, config.layout_a(conv_kind)}, + {ptr_B, config.layout_b(conv_kind)}, + {ptr_C, config.layout_c(conv_kind)}, + {ptr_D, config.layout_c(conv_kind)}, + alpha, + beta + ); + + return Status::kSuccess; + } + else if (kProvider == Provider::kReferenceDevice) { + return cutlass::reference::device::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp + >( + kConvolutionalOperator, + config.problem_size, + {ptr_A, config.layout_a(conv_kind)}, + {ptr_B, config.layout_b(conv_kind)}, + {ptr_C, config.layout_c(conv_kind)}, + {ptr_D, 
config.layout_c(conv_kind)}, + alpha, + beta, + stream + ); + } + return Status::kErrorNotSupported; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + Provider Provider_, + cutlass::conv::Operator ConvolutionalOperator, + int ConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +class ConvReferenceOperation : public Operation { +public: + static Provider const kProvider = Provider_; + static cutlass::conv::Operator const kConvolutionalOperator = ConvolutionalOperator; + static int const kConvDim = ConvDim; + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ConvertOp = ConvertOp_; + using InnerProductOp = InnerProductOp_; + +protected: + + /// Storage for the name string + std::string name_; + + /// + ConvDescription description_; + +public: + + /// Constructor + ConvReferenceOperation() { + + // Basic information + description_.provider = kProvider; + description_.kind = (kConvDim == 2 ? 
OperationKind::kConv2d : OperationKind::kConv3d); + description_.conv_kind = ConvKindMap::kId; + description_.conv_dim = kConvDim; + + // Tensor description + description_.A = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.C = make_TensorDescription(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + // Iterator algorithm for convolution reference + description_.iterator_algorithm = IteratorAlgorithmID::kNone; + + // Compute capability for convolution reference + description_.tile_description.minimum_compute_capability = + (kProvider == Provider::kReferenceDevice ? 50 : 0); + + description_.tile_description.maximum_compute_capability = 1024; + + // Procedural name + std::stringstream ss; + + ss << "conv" << kConvDim << "d_" << to_string(description_.conv_kind) + << "_reference_" << to_string(description_.provider) + << "_" << to_string(description_.A.element) << to_string(description_.A.layout) + << "_" << to_string(description_.B.element) << to_string(description_.B.layout) + << "_" << to_string(description_.C.element) << to_string(description_.C.layout) + << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); + + name_ = ss.str(); + + description_.name = name_.c_str(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + virtual Status can_implement( + void const *configuration, + void const *arguments) const { + + return Status::kSuccess; + } + + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + 
switch (kConvDim) { + case 2: + return sizeof(Conv2dConfiguration); + case 3: + return sizeof(Conv3dConfiguration); + default: + break; + } + + return 0; + } + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const { + + return 0; + } + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + std::memcpy(host_workspace, configuration, get_host_workspace_size(configuration)); + + return Status::kSuccess; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + ConvArguments const &args = *static_cast(arguments); + + ElementCompute alpha; + ElementCompute beta; + + alpha = *static_cast(args.alpha); + beta = *static_cast(args.beta); + + // TODO - respect pointer mode + + // Invoke 2D or 3D convolution + return detail::ConvReferenceDispatcher< + kProvider, + kConvolutionalOperator, + kConvDim, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp + >::dispatch( + host_workspace, + static_cast(const_cast(args.A)), + static_cast(const_cast(args.B)), + static_cast(const_cast(args.C)), + static_cast(args.D), + alpha, + beta, + stream + ); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructs Fprop reference operators. 
+template < + int kConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_conv_fprop(Manifest &manifest) { +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) + manifest.append(new ConvReferenceOperation< + Provider::kReferenceHost, + cutlass::conv::Operator::kFprop, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new ConvReferenceOperation< + Provider::kReferenceDevice, + cutlass::conv::Operator::kFprop, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); +#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) +} + +/// Constructs Dgrad and Wgrad reference operators. 
+template < + int kConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_conv_backwards(Manifest &manifest) { +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) + manifest.append(new ConvReferenceOperation< + Provider::kReferenceHost, + cutlass::conv::Operator::kDgrad, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new ConvReferenceOperation< + Provider::kReferenceDevice, + cutlass::conv::Operator::kDgrad, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new ConvReferenceOperation< + Provider::kReferenceHost, + cutlass::conv::Operator::kWgrad, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new ConvReferenceOperation< + Provider::kReferenceDevice, + cutlass::conv::Operator::kWgrad, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); +#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) +} + +/// Six operators for the price of one. 
+template < + int kConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_conv_all(Manifest &manifest) { + + make_conv_fprop< + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_conv_backwards< + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/gemm_reference_operation.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/gemm_reference_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..e07158b0602eef1d71cfdca95323b3da60553747 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/gemm_reference_operation.h @@ -0,0 +1,543 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Defines reference operations for GEMM operation kinds in CUTLASS Library +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +#include "library_internal.h" + +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/device/gemm_complex.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + Provider Provider_, + typename ElementA_, + typename LayoutA_, + cutlass::ComplexTransform TransformA, + typename ElementB_, + typename LayoutB_, + cutlass::ComplexTransform TransformB, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +class GemmReferenceOperation : public Operation { +public: + static Provider const kProvider = Provider_; + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + static cutlass::ComplexTransform const kTransformA = TransformA; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + static cutlass::ComplexTransform const kTransformB = TransformB; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementD = ElementD_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ConvertOp = ConvertOp_; + using InnerProductOp = InnerProductOp_; + +protected: + + /// 
Storage for the name string + std::string name_; + + /// + GemmDescription description_; + +public: + + /// Constructor + GemmReferenceOperation() { + + // Basic information + description_.provider = kProvider; + description_.kind = OperationKind::kGemm; + description_.gemm_kind = GemmKind::kUniversal; + + // Tensor description + description_.A = make_TensorDescription(); + description_.transform_A = ComplexTransformMap::kId; + description_.B = make_TensorDescription(); + description_.transform_B = ComplexTransformMap::kId; + description_.C = make_TensorDescription(); + description_.D = make_TensorDescription(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + // Compute capability for gemm reference + description_.tile_description.minimum_compute_capability = + (kProvider == Provider::kReferenceDevice ? 50 : 0); + + description_.tile_description.maximum_compute_capability = 1024; + + // Procedural name + std::stringstream ss; + + ss << "gemm" + << "_reference_" << to_string(description_.provider) + << "_" << to_string(description_.A.element) << to_string(description_.A.layout) + << "_" << to_string(description_.B.element) << to_string(description_.B.layout) + << "_" << to_string(description_.C.element) << to_string(description_.C.layout) + << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); + + name_ = ss.str(); + + description_.name = name_.c_str(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + virtual Status can_implement( + void const *configuration, + void 
const *arguments) const { + + return Status::kSuccess; + } + + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(GemmUniversalConfiguration); + } + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const { + + return 0; + } + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + std::memcpy(host_workspace, configuration, get_host_workspace_size(configuration)); + + return Status::kSuccess; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + GemmUniversalConfiguration const &config = *static_cast(host_workspace); + GemmUniversalArguments const &args = *static_cast(arguments); + + TensorRefA ref_A{static_cast(const_cast(args.A)), LayoutA(int(config.lda))}; + TensorRefB ref_B{static_cast(const_cast(args.B)), LayoutB(int(config.ldb))}; + TensorRefC ref_C{static_cast(const_cast(args.C)), LayoutC(int(config.ldc))}; + TensorRefD ref_D{static_cast(args.D), LayoutC(int(config.ldd))}; + + if (kProvider == Provider::kReferenceHost) { + + cutlass::reference::host::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, + InnerProductOp + >( + config.problem_size, + *static_cast(args.alpha), + ref_A, + kTransformA, + ref_B, + kTransformB, + *static_cast(args.beta), + ref_C, + ref_D, + ElementAccumulator(), + ((config.mode == library::GemmUniversalMode::kBatched) ? 
config.batch_count : 1), + args.batch_stride_A, + args.batch_stride_B, + args.batch_stride_C, + args.batch_stride_D + ); + + return Status::kSuccess; + } + else if (kProvider == Provider::kReferenceDevice) { + + cutlass::reference::device::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, + InnerProductOp + >( + config.problem_size, + *static_cast(args.alpha), + ref_A, + kTransformA, + ref_B, + kTransformB, + *static_cast(args.beta), + ref_C, + ref_D, + ElementAccumulator(), + ((config.mode == library::GemmUniversalMode::kBatched) ? config.batch_count : 1), + args.batch_stride_A, + args.batch_stride_B, + args.batch_stride_C, + args.batch_stride_D + ); + + return Status::kSuccess; + } + + return Status::kErrorNotSupported; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename LayoutA_, + cutlass::ComplexTransform TransformA, + typename ElementB_, + typename LayoutB_, + cutlass::ComplexTransform TransformB, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm(Manifest &manifest) { +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) + manifest.append(new GemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, LayoutA_, TransformA, + ElementB_, LayoutB_, TransformB, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new GemmReferenceOperation< + Provider::kReferenceDevice, + ElementA_, LayoutA_, TransformA, + ElementB_, LayoutB_, TransformB, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >); +#endif +} + +/// Helper 
to create NN, NT, TN, and TT GEMM layouts. +template < + typename ElementA_, cutlass::ComplexTransform TransformA, + typename ElementB_, cutlass::ComplexTransform TransformB, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_canonical_layouts(Manifest &manifest) { + + // M Major outputs + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + // N Major outputs + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + 
ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); +} + + +/// Helper to create TN and interleaved layouts GEMM layouts. +template < + int InterleaveK, + typename ElementA_, + typename ElementB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_interleaved_layouts(Manifest &manifest) { + + make_gemm< + ElementA_, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + +} + +/// Helper to real-valued GEMM with canonical layouts +template < + typename ElementA_, + typename ElementB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_real_canonical_layouts(Manifest &manifest) { + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::ComplexTransform::kNone, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); 
+} + +// Helper to create all complex transformation permutations +template < + typename ElementA_, + typename ElementB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_complex_canonical_layouts(Manifest &manifest) { + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::ComplexTransform::kNone, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kConjugate, + ElementB_, cutlass::ComplexTransform::kConjugate, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::ComplexTransform::kConjugate, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kConjugate, + ElementB_, cutlass::ComplexTransform::kNone, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/sparse_gemm_operation_3x.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/sparse_gemm_operation_3x.hpp new file mode 100644 index 
0000000000000000000000000000000000000000..01caa11e229ffd9109b0973dcca01064df448fa3 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/sparse_gemm_operation_3x.hpp @@ -0,0 +1,504 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/array.h" +#include "cutlass/array_subbyte.h" +#include "cutlass/library/library.h" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // StructuredSparseCompressor +#include "cutlass/transform/device/transform_universal_adapter.hpp" // TransformUniversalAdapter +#include "cutlass/util/packed_stride.hpp" // make_cute_packed_stride +#include "gemm_operation_3x.hpp" +#include "library_internal.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cute/tensor.hpp" +#include + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Limitation & Assumptions: +// 1. The tensor must be densely packed. That is, lda is k if the tensor is k-major, +// and lda is m if the tensor is m-major. +// 2. Circular buffer for tensorA and tensorE may have a less count compared to tensorB and others. +// This is because we can not get the problem_count information in the get_device_workspace_size(). +// But I can promise it will use at least 192MB memory if we enable circular buffer. 
+template +class SparseGemmUniversal3xOperation : public GemmOperation3xBase { +public: + + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + using ElementE = typename CollectiveMainloop::ElementE; + using LayoutE = typename CollectiveMainloop::LayoutE; + using SparseConfig = typename CollectiveMainloop::SparseConfig; + using LayoutATag = decltype(SparseConfig::deduce_layoutA_tag(typename CollectiveMainloop::LayoutA{})); + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + cute::Shape, + ElementA, + LayoutATag, + SparseConfig>; + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + cute::Shape, + ElementA, + LayoutATag, + SparseConfig, + 
typename Operator::ArchTag>; + + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + +public: + + /// Constructor + SparseGemmUniversal3xOperation(char const *name = "unknown_gemm"): + GemmOperation3xBase(name, GemmKind::kUniversal) {} + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmUniversalArguments const &arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, GemmUniversalArguments const &arguments) { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + 
GemmUniversalArguments const *arguments, + CompressorUtility const& compressor_utility, + void* device_a_compressed_ptr = nullptr, + void* device_e_ptr = nullptr) { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.problem_shape = cute::make_shape( + arguments->problem_size.m(), + arguments->problem_size.n(), + arguments->problem_size.k(), + arguments->batch_count); + + // update arguments + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(device_a_compressed_ptr); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + + std::unordered_map mapping = { + {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3}, + {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2}, + {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2}, + {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1} + }; + + auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a); + auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b); + + if (iter_runtime_a != mapping.end()) { + operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second; + } else { + assert("invalid runtime argument for datatype A!"); + } + + if (iter_runtime_b != mapping.end()) { + operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second; + } else { + assert("invalid runtime argument for datatype B!"); + } + + } + else { + operator_args.mainloop.ptr_A = static_cast(device_a_compressed_ptr); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + } + operator_args.mainloop.ptr_E = static_cast(device_e_ptr); + operator_args.epilogue.ptr_C = static_cast(arguments->C); + operator_args.epilogue.ptr_D = 
static_cast(arguments->D); + + operator_args.mainloop.layout_a = compressor_utility.fill_layoutA_from_compressor(); + operator_args.mainloop.layout_e = compressor_utility.fill_layoutE_from_compressor(); + operator_args.mainloop.dB = cute::make_int_tuple_from( + arguments->ldb, arguments->batch_stride_B); + operator_args.epilogue.dC = cute::make_int_tuple_from( + arguments->ldc, arguments->batch_stride_C); + operator_args.epilogue.dD = operator_args.epilogue.dC; + + /* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments->raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = arguments->split_k_slices; + } + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + return status; + } + +public: + + /// Returns success if the operation can proceed + Status can_implement( + void const *configuration_ptr, void const *arguments_ptr) const override { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + GemmUniversalArguments const *arguments = + static_cast(arguments_ptr); + + 
OperatorArguments args; + auto problem_shape_MNKL = cute::make_shape( + configuration->problem_size.m(), + configuration->problem_size.n(), + configuration->problem_size.k(), + configuration->batch_count); + + const int M = configuration->problem_size.m(); + const int N = configuration->problem_size.n(); + const int K = configuration->problem_size.k(); + const int L = configuration->batch_count; + using StrideA = typename CompressorUtility::StrideA; + auto dA = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + compressor_utility.set_problem_size(problem_shape_MNKL, dA); + auto status = update_arguments_(args, arguments, compressor_utility); + if (status != Status::kSuccess) { + return status; + } + + // can_implement rules may need access to problem shape + args.problem_shape = problem_shape_MNKL; + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const *) const override { + // Memory to hold operator + host_op_workspace_size = sizeof(Operator); + + // Memory to hold result of `.structure_sparse_zero_mask_fill()` + tensor_a_size = compressor_utility.get_raw_tensor_A_bytes(); + + // NOTE: order here is the order of workspace partition + const uint64_t size = host_op_workspace_size + tensor_a_size; + + return size; + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, static_cast(arguments_ptr), compressor_utility); + if (status != Status::kSuccess) { + return 0; + } + + typename Compressor::Arguments compress_arguments { + {compressor_utility.M, 0, compressor_utility.K, compressor_utility.L}, + {/*Empty Not Use*/}, + {/*Empty Not Use*/} }; + + // Size for one iteration + // For multi-iteration, will need to multiply result of this function w/ actual problem_count + tensor_ac_size = 
compressor_utility.get_compressed_tensor_A_bytes(); + tensor_e_size = compressor_utility.get_tensor_E_bytes(); + device_op_workspace_size = Operator::get_workspace_size(args); + device_compress_workspace_size = Compressor::get_workspace_size(compress_arguments); + + // NOTE: order here is the order of workspace partition + device_per_iter_workspace_size = device_op_workspace_size + device_compress_workspace_size + tensor_ac_size + tensor_e_size; + + return device_per_iter_workspace_size; + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + return Status::kErrorInternal; + } + + Status initialize_with_profiler_workspace( + void const *configuration, + void *host_workspace, + void *device_workspace, + uint8_t **profiler_workspaces, + int problem_count_from_profiler, + cudaStream_t stream = nullptr) { + + iter_idx.resize(static_cast(configuration)->device_count, 0); + + // Set problem_count. + problem_count = problem_count_from_profiler; + + // * Host Ptr + auto* host_op_workspace_ptr = reinterpret_cast(host_workspace); + auto* host_a_raw_ptr = host_op_workspace_ptr + host_op_workspace_size; + + // * Construct Op + Operator *op = new (host_op_workspace_ptr) Operator; + + // * Device Ptr (1st iteration) + // Device workspace : | iter1 | iter2 | iter3 | .. 
| iterx | + // iteri : op_workspace | tensor_ac | tensor_e + auto* device_ptr_iter1 = static_cast(device_workspace); + auto* device_op_workspace_ptr_iter1 = device_ptr_iter1; + auto* device_compressor_workspace_ptr_iter1 = device_op_workspace_ptr_iter1 + device_op_workspace_size; + auto* device_a_compressed_ptr_iter1 = device_compressor_workspace_ptr_iter1 + device_compress_workspace_size; + auto* device_e_ptr_iter1 = device_a_compressed_ptr_iter1 + tensor_ac_size; + + // * Device A Raw Ptr + auto* device_a_raw_ptr = profiler_workspaces[0]; + + // * Random fill 50% of TensorA w/ zero following the structured sparse requirement + CUDA_CHECK(cudaMemcpyAsync(host_a_raw_ptr, device_a_raw_ptr, tensor_a_size, cudaMemcpyDeviceToHost, stream)); + compressor_utility.structure_sparse_zero_mask_fill(host_a_raw_ptr, 2000); + CUDA_CHECK(cudaMemcpyAsync(device_a_raw_ptr, host_a_raw_ptr, tensor_a_size, cudaMemcpyHostToDevice, stream)); + + CUDA_CHECK(cudaGetLastError()); + + // * Compress DTensorA and get DTensorAC & DTensorE + cutlass::KernelHardwareInfo hw_info; + CUDA_CHECK(cudaGetDevice(&hw_info.device_id)); + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {compressor_utility.M, 0, compressor_utility.K, compressor_utility.L}, + {device_a_raw_ptr, + compressor_utility.dA, + device_a_compressed_ptr_iter1, + device_e_ptr_iter1}, + {hw_info} + }; + + cutlass::Status status {cutlass::Status::kSuccess }; + + Compressor compressor_op; + status = compressor_op.can_implement(arguments); + if (status != Status::kSuccess) { + return status; + } + + status = compressor_op.initialize(arguments, device_compressor_workspace_ptr_iter1, stream); + if (status != Status::kSuccess) { + return status; + } + + status = compressor_op.run(stream); + if (status != Status::kSuccess) { + return status; + } + + // * Copy Iter1's DTensorAC DTensorE to each iteration's DTensorAC DTensorE + for (int iter_i 
= 1; iter_i < problem_count; iter_i++) { + // * Device AC E Ptr per iteration + // Device workspace : | iter1 | iter2 | iter3 | .. | iterx | + // iteri : op_workspace | tensor_ac | tensor_e + auto* device_ptr_iteri = static_cast(device_workspace) + device_per_iter_workspace_size * iter_i; + auto* device_op_workspace_ptr = device_ptr_iteri; + auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size; + auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size; + auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size; + + CUDA_CHECK(cudaMemcpyAsync(device_a_compressed_ptr, device_a_compressed_ptr_iter1, tensor_ac_size, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(device_e_ptr, device_e_ptr_iter1, tensor_e_size, cudaMemcpyDeviceToDevice, stream)); + } + + CUDA_CHECK(cudaStreamSynchronize(stream)); + + CUDA_CHECK(cudaGetLastError()); + + return Status::kSuccess; + } + + /// Runs the kernel + Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + + OperatorArguments operator_args; + + + const auto device_index = static_cast(arguments_ptr)->device_index; + + auto* device_ptr_iteri = static_cast(device_workspace) + device_per_iter_workspace_size * iter_idx[device_index]; + auto* device_op_workspace_ptr = device_ptr_iteri; + auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size; + auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size; + auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size; + iter_idx[device_index] = (iter_idx[device_index] + 1) % problem_count; + + Status status = update_arguments_(operator_args, static_cast(arguments_ptr), compressor_utility, device_a_compressed_ptr, device_e_ptr ); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = 
static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(operator_args, device_op_workspace_ptr, stream, nullptr, + static_cast(arguments_ptr)->use_pdl); + return status; + } + +private: + // Variables that must change in the const functions. + mutable CompressorUtility compressor_utility; + mutable int problem_count = 1; + mutable std::vector iter_idx; + + mutable uint64_t tensor_ac_size = 0; + mutable uint64_t tensor_e_size = 0; + mutable uint64_t tensor_a_size = 0; + mutable uint64_t host_op_workspace_size = 0; + mutable uint64_t device_compress_workspace_size = 0; + mutable uint64_t device_op_workspace_size = 0; + mutable uint64_t device_per_iter_workspace_size = 0; +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/symm_operation.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/symm_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..c95d238a81f825dbbeae689ec452467cc8ca3afa --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/symm_operation.h @@ -0,0 +1,382 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. 
+ * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all Symm operation kinds (Symm, Hemm) + in CUTLASS Library. 
+ + +*/ + +#pragma once +#include +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/symm.h" +#include "cutlass/gemm/kernel/default_symm_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/core_io.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class SymmOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static BlasMode const kBlasMode = Operator::kBlasMode; + static SideMode const kSideModeA = Operator::kSideModeA; + static FillMode const kFillModeA = Operator::kFillModeA; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + SymmDescription description_; + +public: + + /// Constructor + SymmOperationBase(char const *name = "unknown_symm") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.symm_kind = SymmKind::kUniversal; + description_.side_mode = kSideModeA; + description_.fill_mode = kFillModeA; + description_.blas_mode = kBlasMode; + + description_.kind = OperationKind::kSymm; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = 
make_Coord( + Operator::SymmKernel::WarpCount::kM, + Operator::SymmKernel::WarpCount::kN, + Operator::SymmKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + } + + /// Returns the description of the SYMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class SymmOperation : public SymmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + static BlasMode const kBlasMode = Operator::kBlasMode; + static SideMode const kSideModeA = Operator::kSideModeA; + static FillMode const kFillModeA 
= Operator::kFillModeA; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + SymmOperation(char const *name = "unknown_symm"): + SymmOperationBase(name) { + + this->description_.symm_kind = SymmKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + SymmConfiguration const *configuration) { + + //operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = int(configuration->lda); + operator_args.ldb = int(configuration->ldb); + operator_args.ldc = int(configuration->ldc); + operator_args.ldd = int(configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + SymmArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_B = arguments->B; + operator_args.ptr_C = arguments->C; + operator_args.ptr_D = arguments->D; + + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_B = arguments->batch_stride_B; + operator_args.batch_stride_C = arguments->batch_stride_C; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if 
(arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + SymmConfiguration const *configuration = + static_cast(configuration_ptr); + + SymmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + //std::cout << "initialize() library::SymmOperation" << std::endl; + //print_operator_args(args); + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = 
nullptr, + cudaStream_t stream = nullptr) const { + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + bool need_swapped_matrices = (kSideModeA == SideMode::kLeft && + std::is_same::value) || + (kSideModeA == SideMode::kRight && + std::is_same::value); + if (need_swapped_matrices) { + status = op->update(args.swapped_matrices(), device_workspace); + } else { + status = op->update(args, device_workspace); + } + + if (status != Status::kSuccess) { + return status; + } + + //std::cout << "run() library::SymmOperation" << std::endl; + //print_operator_args(args); + status = op->run(stream); + + return status; + } + + /// Call print_operator_args from the Conv2dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "SymmOperation::OperatorArguments" << std::endl + << " problem_size:" << std::endl + << operator_args.problem_size << std::endl + << " epilogue (alpha, beta): " + << operator_args.epilogue.alpha << ", " + << operator_args.epilogue.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ptr_A << ", {" + << operator_args.lda << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ptr_B << ", {" + << operator_args.ldb << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ptr_C << ", {" + << operator_args.ldc << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ptr_D << ", {" + << operator_args.ldd << "}" << std::endl; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/trmm_operation.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/trmm_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..d419723791ace5d90eb7955223be9db72bbc2c3c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/library/src/trmm_operation.h @@ -0,0 +1,350 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all TRMM operation kinds in CUTLASS Library. + + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/trmm.h" +#include "cutlass/gemm/kernel/default_trmm_universal.h" +#include "cutlass/gemm/kernel/trmm_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class TrmmOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + static SideMode const kSideMode = Operator::kSideMode; + static FillMode const kFillMode = Operator::kFillMode; + static DiagType const kDiagType = Operator::kDiagType; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename 
Operator::Arguments; + +protected: + + /// + TrmmDescription description_; + +public: + + /// Constructor + TrmmOperationBase(char const *name = "unknown_trmm") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kTrmm; + description_.trmm_kind = TrmmKind::kUniversal; + description_.side_mode = kSideMode; + description_.fill_mode = kFillMode; + description_.diag_type = kDiagType; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::TrmmKernel::WarpCount::kM, + Operator::TrmmKernel::WarpCount::kN, + Operator::TrmmKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.D = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + } + + /// Returns the description of the TRMM operation + virtual OperationDescription const & description() const { + return description_; + } 
+}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class TrmmOperation : public TrmmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + static SideMode const kSideMode = Operator::kSideMode; + static FillMode const kFillMode = Operator::kFillMode; + static DiagType const kDiagType = Operator::kDiagType; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + TrmmOperation(char const *name = "unknown_trmm"): + TrmmOperationBase(name) { + + this->description_.trmm_kind = TrmmKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + TrmmConfiguration const *configuration) { + + //operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = int(configuration->lda); + operator_args.ldb = int(configuration->ldb); + operator_args.ldd = int(configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + TrmmArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue 
= params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_B = arguments->B; + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_B = arguments->batch_stride_B; + operator_args.ptr_D = arguments->D; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + TrmmConfiguration const *configuration = + static_cast(configuration_ptr); + + TrmmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void 
*host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + bool need_swapped_matrices = (kSideMode == SideMode::kLeft && + std::is_same::value) || + (kSideMode == SideMode::kRight && + std::is_same::value); + if (need_swapped_matrices) { + status = op->update(args.swapped_matrices(), device_workspace); + } else { + status = op->update(args, device_workspace); + } + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..5d500d9149bf645eadf8110d98612c40882d742c --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h @@ -0,0 +1,330 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Blockscale Gemm Profiler +*/ + + + +#pragma once + +#include +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class BlockScaledGemmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct GemmProblem { + + cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm}; + + /// For profiling purposes + std::vector problem_sizes; + std::vector> leading_dims; + std::vector> preferred_clusters; + std::vector> fallback_clusters; + std::vector raster_orders; + std::vector swizzle_sizes; + + int64_t m{16}; + int64_t n{16}; + int64_t k{16}; + + + int cluster_m{1}; + int cluster_n{1}; + int cluster_k{1}; + int cluster_m_fallback{1}; + int cluster_n_fallback{1}; + int cluster_k_fallback{1}; + + + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + std::vector alpha; + std::vector beta; + + cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone}; + int split_k_slices{1}; + int batch_count{1}; + + cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; + int swizzle_size{1}; + cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; + cutlass::library::RuntimeDatatype 
runtime_input_datatype_b{}; + + + // gemm with parallel interleaved reduction + // gemm epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + bool use_pdl{false}; + // + // Methods + // + + /// Parses the problem + Status parse( + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + int64_t bytes_with_problem_shape( + library::BlockScaledGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + int64_t flops_with_problem_shape( + library::BlockScaledGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + /// Total number of bytes loaded + int64_t bytes(library::BlockScaledGemmDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::BlockScaledGemmDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct GemmWorkspace { + + DeviceAllocation *A{nullptr}; + DeviceAllocation *SFA{nullptr}; + DeviceAllocation *B{nullptr}; + DeviceAllocation *SFB{nullptr}; + DeviceAllocation *C{nullptr}; + DeviceAllocation *Computed{nullptr}; + DeviceAllocation *Reference{nullptr}; + DeviceAllocation *Computed_SFD{nullptr}; + DeviceAllocation *Reference_SFD{nullptr}; + DeviceAllocation *Norm_constant{nullptr}; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. 
+ int problem_count{1}; + + library::GemmUniversalConfiguration configuration; + library::BlockScaledGemmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + + cudaStream_t stream; + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + GemmProblem problem_; + + /// Device memory allocations + GemmWorkspace gemm_workspace_; + + /// CUTLASS parallel reduction operation to follow this* gemm operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + BlockScaledGemmOperationProfiler(Options const &options); + + /// Destructor + virtual ~BlockScaledGemmOperationProfiler(); + + GemmProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const 
&options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Update workspace configuration according to flexible user setups + void update_workspace_( + GemmWorkspace &gemm_workspace, + gemm::GemmCoord const &problem_shape, + std::array const &leading_dim, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + cutlass::library::RasterOrder const &raster_order, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Update performance result configuration according to flexible user setups + void update_result_( + PerformanceResult &result, + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + gemm::GemmCoord const &problem_shape, + cutlass::library::RasterOrder const &raster_order, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + 
library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); + + /// Method to profile a CUTLASS Operation + Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + library::Operation const *operation, + ProblemSpace::Problem const &problem); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..c110de278cac640c1cedd8dd29d1b8ac09de81ef --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h @@ -0,0 +1,305 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Blockscale Gemm Profiler +*/ + + + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class BlockwiseGemmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct GemmProblem { + + cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm}; + + int64_t m{16}; + int64_t n{16}; + int64_t k{16}; + + int64_t sf_vec_m{0}; + int64_t sf_vec_n{0}; + int64_t sf_vec_k{0}; + + int cluster_m{1}; + int cluster_n{1}; + int cluster_k{1}; + int cluster_m_fallback{1}; + int cluster_n_fallback{1}; + int cluster_k_fallback{1}; + + + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + std::vector alpha; + std::vector beta; + + cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone}; + int split_k_slices{1}; + int batch_count{1}; + + cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; + int swizzle_size{1}; + + /// For profiling purposes + std::vector problem_sizes; + std::vector> leading_dims; + std::vector> preferred_clusters; + std::vector> fallback_clusters; + std::vector raster_orders; + std::vector swizzle_sizes; + + cutlass::library::RuntimeDatatype 
runtime_input_datatype_a{}; + cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; + + + // gemm with parallel interleaved reduction + // gemm epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + bool use_pdl{false}; + // + // Methods + // + + /// Parses the problem + Status parse( + library::BlockwiseGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + int64_t bytes_with_problem_shape( + library::BlockwiseGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + int64_t flops_with_problem_shape( + library::BlockwiseGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + /// Total number of bytes loaded + int64_t bytes(library::BlockwiseGemmDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::BlockwiseGemmDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::BlockwiseGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct GemmWorkspace { + + DeviceAllocation *A{nullptr}; + DeviceAllocation *SFA{nullptr}; + DeviceAllocation *B{nullptr}; + DeviceAllocation *SFB{nullptr}; + DeviceAllocation *C{nullptr}; + DeviceAllocation *Computed{nullptr}; + DeviceAllocation *Reference{nullptr}; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. 
+ int problem_count{1}; + + library::GemmUniversalConfiguration configuration; + library::BlockwiseGemmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + GemmProblem problem_; + + /// Device memory allocations + GemmWorkspace gemm_workspace_; + + /// CUTLASS parallel reduction operation to follow this* gemm operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + BlockwiseGemmOperationProfiler(Options const &options); + + /// Destructor + virtual ~BlockwiseGemmOperationProfiler(); + + GemmProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport 
&report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::BlockwiseGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); + + /// Method to profile a CUTLASS Operation + Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + library::Operation const *operation, + ProblemSpace::Problem const &problem); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..683465f50cda19c8d505f2e66bcb60173d7e942d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h @@ -0,0 +1,495 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines profiling functionality for convolution + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/handle.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/singleton.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" +#if CUTLASS_ENABLE_CUDNN +#include "cudnn_helpers.h" +#endif //#if CUTLASS_ENABLE_CUDNN +#include "debug.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class Conv2dOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct Conv2dProblem { + + int64_t n, h, w, c, p, q, k, r, s; + int64_t groups; + int64_t pad_h, pad_w; + int64_t stride_h, stride_w; + int64_t dilation_h, dilation_w; + + std::vector alpha; + std::vector beta; + + library::SplitKMode split_k_mode; + 
int64_t split_k_slices; + + library::ConvModeID conv_mode; + + library::Provider eq_gemm_provider; + + // convolution with parallel interleaved reduction + // convolution epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (Conv2dProblem::alpha, Conv2dProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + // + // Methods + // + + /// Total number of bytes loaded + int64_t bytes(library::ConvDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::ConvDescription const &operation_desc) const; + + void set_default_output_size() { + p = ((h + pad_h - r * dilation_h) / stride_h) + 1; + q = ((w + pad_w - s * dilation_w) / stride_w) + 1; + } + + // Returns equivalent gemm problem size for convolution + cutlass::gemm::GemmCoord eq_gemm_size(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return cutlass::gemm::GemmCoord(int(n * p * q), int(k), int(r * s * c / groups)); + case library::ConvKind::kDgrad: return cutlass::gemm::GemmCoord(int(n * h * w), int(c), int(k * r * s)); + case library::ConvKind::kWgrad: return cutlass::gemm::GemmCoord(int(k), int(r * s * c), int(n * p * q)); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor A + std::vector extent_a(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(n), int(h), int(w), int(c)}; + case library::ConvKind::kDgrad: return {int(n), int(p), int(q), int(k)}; + case library::ConvKind::kWgrad: return {int(n), int(p), int(q), int(k)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor B + std::vector extent_b(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(k), int(r), int(s), int(c / groups)}; + case 
library::ConvKind::kDgrad: return {int(k), int(r), int(s), int(c)}; + case library::ConvKind::kWgrad: return {int(n), int(h), int(w), int(c)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor C + std::vector extent_c(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(n), int(p), int(q), int(k)}; + case library::ConvKind::kDgrad: return {int(n), int(h), int(w), int(c)}; + case library::ConvKind::kWgrad: return {int(k), int(r), int(s), int(c)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix A + library::LayoutTypeID eq_gemm_layout_a(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return library::LayoutTypeID::kRowMajor; // TN Gemm + case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm + case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; // NT Gemm + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix B + library::LayoutTypeID eq_gemm_layout_b(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return library::LayoutTypeID::kColumnMajor; // TN Gemm + case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm + case library::ConvKind::kWgrad: return library::LayoutTypeID::kRowMajor; // NT Gemm + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix C + library::LayoutTypeID eq_gemm_layout_c(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + // Gemm operator assumes column-major output + case library::ConvKind::kFprop: + case library::ConvKind::kDgrad: + case 
library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix A + int64_t eq_gemm_lda(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix B + int64_t eq_gemm_ldb(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).n(); + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).n(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix C + int64_t eq_gemm_ldc(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: + case library::ConvKind::kDgrad: + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + }; + + /// Workspace used + struct Conv2dWorkspace { + + /// Conv device allocations + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *reordered_B; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + /// Library configuration and arguments for convolution operator + library::Conv2dConfiguration configuration; + library::ConvArguments arguments; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level 
cache. + int problem_count; + + /// Buffer used for the cutlass conv2d operations' host workspace + std::vector host_workspace; + + /// Buffer used for the cutlass operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + + /// Host data buffers for host reference operation + /// host buffer for tensor + std::vector host_tensor_a; + + /// host buffer for tensor b + std::vector host_tensor_b; + + /// host buffer for tensor c + std::vector host_tensor_c; + + // + // Methods + // + + Conv2dWorkspace() + : A(nullptr), + B(nullptr), + reordered_B(nullptr), + C(nullptr), + Computed(nullptr), + Reference(nullptr) {} + + // Set stride vector for tensor activations, filters, output + void set_stride_vector(Conv2dProblem const &problem, + library::ConvKind const &conv_kind, + library::LayoutTypeID const &layout_a, + library::LayoutTypeID const &layout_b, + library::LayoutTypeID const &layout_c) { + std::vector stride_activations; + std::vector stride_filters; + std::vector stride_output; + + // Strides for interleaved fprop + if (conv_kind == library::ConvKind::kFprop && + ((layout_a == library::LayoutTypeID::kTensorNC32HW32 && + layout_b == library::LayoutTypeID::kTensorC32RSK32 && + layout_c == library::LayoutTypeID::kTensorNC32HW32) || + (layout_a == library::LayoutTypeID::kTensorNC64HW64 && + layout_b == library::LayoutTypeID::kTensorC64RSK64 && + layout_c == library::LayoutTypeID::kTensorNC64HW64))) { + int interleave = + (layout_a == library::LayoutTypeID::kTensorNC32HW32) ? 
32 : 64; + + stride_activations.push_back(int(problem.w) * interleave); + stride_activations.push_back(int(problem.w) * int(problem.h) * + interleave); + stride_activations.push_back(int(problem.h) * int(problem.w) * + int(problem.c)); + + stride_filters.push_back(int(problem.k) * interleave); + stride_filters.push_back(int(problem.k) * int(problem.s) * interleave); + stride_filters.push_back(int(problem.k) * int(problem.s) * + int(problem.r) * interleave); + + stride_output.push_back(int(problem.q) * interleave); + stride_output.push_back(int(problem.q) * int(problem.p) * interleave); + stride_output.push_back(int(problem.q) * int(problem.p) * + int(problem.k)); + } else { + // Strides for the rest cases + stride_activations.push_back(int(problem.c)); + stride_activations.push_back(int(problem.w) * int(problem.c)); + stride_activations.push_back(int(problem.h) * int(problem.w) * + int(problem.c)); + + stride_filters.push_back(int(problem.c / problem.groups)); + stride_filters.push_back(int(problem.s) * int(problem.c / problem.groups)); + stride_filters.push_back(int(problem.r) * int(problem.s) * + int(problem.c / problem.groups)); + + stride_output.push_back(int(problem.k)); + stride_output.push_back(int(problem.q) * int(problem.k)); + stride_output.push_back(int(problem.q) * int(problem.p) * + int(problem.k)); + } + + switch (conv_kind) { + case library::ConvKind::kFprop: + configuration.stride_a = stride_activations; + configuration.stride_b = stride_filters; + configuration.stride_c = stride_output; + + break; + case library::ConvKind::kDgrad: + configuration.stride_a = stride_output; + configuration.stride_b = stride_filters; + configuration.stride_c = stride_activations; + + break; + case library::ConvKind::kWgrad: + configuration.stride_a = stride_output; + configuration.stride_b = stride_activations; + configuration.stride_c = stride_filters; + + break; + default: + throw std::runtime_error( + "Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + }; + 
+protected: + + // + // Data members + // + + /// CONV problem obtained from problem space + Conv2dProblem problem_; + + /// Device memory allocations + Conv2dWorkspace conv_workspace_; + + /// CUTLASS parallel reduction operation to follow this* conv2d operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + Conv2dOperationProfiler(Options const &options); + + /// Destructor + virtual ~Conv2dOperationProfiler(); + + Conv2dProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + /// Method to profile an initialized CUTLASS operation + virtual Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation 
const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::ConvDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against host reference + bool verify_with_host_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against device reference + bool verify_with_device_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +#if CUTLASS_ENABLE_CUDNN + + /// Verifies CUTLASS against cudnn reference + bool verify_with_cudnn_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +#endif //#if CUTLASS_ENABLE_CUDNN + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..ac4abdef238b00f216053419620a60dfccfd5316 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h @@ -0,0 +1,449 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines profiling functionality for convolution + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/handle.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/singleton.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" +#if CUTLASS_ENABLE_CUDNN +#include "cudnn_helpers.h" +#endif //#if CUTLASS_ENABLE_CUDNN +#include "debug.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class Conv3dOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct Conv3dProblem { + + int64_t n, d, h, w, c, z, p, q, k, t, r, s; + int64_t pad_d, pad_h, pad_w; + int64_t stride_d, stride_h, stride_w; + int64_t dilation_d, dilation_h, dilation_w; + + std::vector alpha; + std::vector beta; + + 
library::SplitKMode split_k_mode; + int64_t split_k_slices; + + library::ConvModeID conv_mode; + + library::Provider eq_gemm_provider; + + // convolution with parallel interleaved reduction + // convolution epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (Conv3dProblem::alpha, Conv3dProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + // + // Methods + // + + /// Total number of bytes loaded + int64_t bytes(library::ConvDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::ConvDescription const &operation_desc) const; + + /// Infers output size from the input size, padding, stride, and dilation + void set_default_output_size() { + z = ((d + pad_d - t * dilation_d) / stride_d) + 1; + p = ((h + pad_h - r * dilation_h) / stride_h) + 1; + q = ((w + pad_w - s * dilation_w) / stride_w) + 1; + } + + // Returns equivalent gemm problem size for convolution + cutlass::gemm::GemmCoord eq_gemm_size(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return cutlass::gemm::GemmCoord(int(n * z * p * q), int(k), int(t * r * s * c)); + case library::ConvKind::kDgrad: return cutlass::gemm::GemmCoord(int(n * d * h * w), int(c), int(t * r * s * k)); + case library::ConvKind::kWgrad: return cutlass::gemm::GemmCoord(int(k), int(t * r * s * c), int(n * z * p * q)); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor A + std::vector extent_a(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(n), int(d), int(h), int(w), int(c)}; + case library::ConvKind::kDgrad: return {int(n), int(z), int(p), int(q), int(k)}; + case library::ConvKind::kWgrad: return {int(n), int(z), int(p), int(q), int(k)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for 
tensor B + std::vector extent_b(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(k), int(t), int(r), int(s), int(c)}; + case library::ConvKind::kDgrad: return {int(k), int(t), int(r), int(s), int(c)}; + case library::ConvKind::kWgrad: return {int(n), int(d), int(h), int(w), int(c)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor C + std::vector extent_c(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(n), int(z), int(p), int(q), int(k)}; + case library::ConvKind::kDgrad: return {int(n), int(d), int(h), int(w), int(c)}; + case library::ConvKind::kWgrad: return {int(k), int(t), int(r), int(s), int(c)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix A + library::LayoutTypeID eq_gemm_layout_a(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return library::LayoutTypeID::kRowMajor; // TN Gemm + case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm + case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; // NT Gemm + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix B + library::LayoutTypeID eq_gemm_layout_b(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return library::LayoutTypeID::kColumnMajor; // TN Gemm + case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm + case library::ConvKind::kWgrad: return library::LayoutTypeID::kRowMajor; // NT Gemm + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix C + 
library::LayoutTypeID eq_gemm_layout_c(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + // Gemm operator assumes column-major output + case library::ConvKind::kFprop: + case library::ConvKind::kDgrad: + case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix A + int64_t eq_gemm_lda(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix B + int64_t eq_gemm_ldb(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).n(); + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).n(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix C + int64_t eq_gemm_ldc(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: + case library::ConvKind::kDgrad: + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + }; + + /// Workspace used + struct Conv2dWorkspace { + + /// Conv device allocations + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + /// Library configuration and arguments for convolution operator + 
library::Conv3dConfiguration configuration; + library::ConvArguments arguments; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + int problem_count; + + /// Buffer used for the cutlass conv2d operations' host workspace + std::vector host_workspace; + + /// Buffer used for the cutlass operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + + /// Host data buffers for host reference operation + /// host buffer for tensor + std::vector host_tensor_a; + + /// host buffer for tensor b + std::vector host_tensor_b; + + /// host buffer for tensor c + std::vector host_tensor_c; + + + // + // Methods + // + + Conv2dWorkspace(): + A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } + + // Returns stride vector for tensor A + std::vector stride_a(library::ConvKind const &conv_kind) { + return { + configuration.layout_a(conv_kind).stride()[0], + configuration.layout_a(conv_kind).stride()[1], + configuration.layout_a(conv_kind).stride()[2], + configuration.layout_a(conv_kind).stride()[3] + }; + } + + // Returns stride vector for tensor B + std::vector stride_b(library::ConvKind const &conv_kind) { + + return { + configuration.layout_b(conv_kind).stride()[0], + configuration.layout_b(conv_kind).stride()[1], + configuration.layout_b(conv_kind).stride()[2], + configuration.layout_b(conv_kind).stride()[3] + }; + } + + // Returns stride vector for tensor C + std::vector stride_c(library::ConvKind const &conv_kind) { + + return { + configuration.layout_c(conv_kind).stride()[0], + configuration.layout_c(conv_kind).stride()[1], + 
configuration.layout_c(conv_kind).stride()[2], + configuration.layout_c(conv_kind).stride()[3] + }; + } + }; + +protected: + + // + // Data members + // + + /// CONV problem obtained from problem space + Conv3dProblem problem_; + + /// Device memory allocations + Conv2dWorkspace conv_workspace_; + + /// CUTLASS parallel reduction operation to follow this* conv2d operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + Conv3dOperationProfiler(Options const &options); + + /// Destructor + virtual ~Conv3dOperationProfiler(); + + Conv3dProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Updates the arguments structure for the CUTLASS 
operator based on + /// the problem index. + void set_cutlass_operator_arguments_(int problem_idx = 0); + + /// Method to profile an initialized CUTLASS operation + virtual Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::ConvDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against host reference + bool verify_with_host_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against device reference + bool verify_with_device_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +#if CUTLASS_ENABLE_CUDNN + + /// Verifies CUTLASS against cudnn reference + bool verify_with_cudnn_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +#endif //#if CUTLASS_ENABLE_CUDNN + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cublas_helpers.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cublas_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..873ba1abe03c05df29edc032ea3f1ffd2f19c3ee --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cublas_helpers.h @@ -0,0 +1,456 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Helper functions for mapping CUTLASS concepts to cuBLAS. +*/ + +#pragma once + +#if CUTLASS_ENABLE_CUBLAS +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/blas3.h" + +#include "options.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Converts a cuBLAS status to cutlass::Status +Status get_cutlass_status(cublasStatus_t cublas); + +/// Converts a cuBLAS status to cutlass::profiler::Disposition +Disposition get_cutlass_disposition(cublasStatus_t cublas_status); + +/// Maps a CUTLASS tensor layout to a cuBLAS transpose operation +bool get_cublas_transpose_operation( + cublasOperation_t &operation, + library::LayoutTypeID layout, + library::ComplexTransform transform = library::ComplexTransform::kNone); + +/// Maps a CUTLASS numeric type to a cuBLAS data type enumeration +bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID element_type); + +/// Gets the cublas algorithm given threadblock tile dimensions and math opcode class +cublasGemmAlgo_t get_cublas_gemm_algo( + int cta_m, + int cta_n, + int 
cta_k, + library::OpcodeClassID opcode_class); + +/// Returns a status if cuBLAS can satisfy a particular GEMM description +Status cublas_satisfies(library::GemmDescription const &desc); + +/// Returns a status if cuBLAS can satisfy a particular RankK description +Status cublas_satisfies(library::RankKDescription const &desc); + +/// Returns a status if cuBLAS can satisfy a particular TRMM description +Status cublas_satisfies(library::TrmmDescription const &desc); + +/// Returns a status if cuBLAS can satisfy a particular SYMM/HEMM description +Status cublas_satisfies(library::SymmDescription const &desc); + +/// This is a helper class to create cublasHandle_t automatically on CublasCreate object creation and +/// to destroy cublasHandle_t on CublasCreate object destruction. +/// Additionally, it provides implicit cast from CublasCreate's object to cublasHandle_t's object +class CublasCreate { +private: + cublasHandle_t handle; + cublasStatus_t status; + +public: + CublasCreate() { + status = cublasCreate(&handle); + } + + ~CublasCreate() { + cublasDestroy(handle); + } + + /// Implicit cast CublasCreate object to cublasHandle_t + operator cublasHandle_t() const { return handle; } + + /// returns cublasStatus_t for handle creation + cublasStatus_t get_cublas_create_status() { return status; } +}; + +/// This is a helper class to create cublasLtHandle_t automatically on CublasLtCreate object creation and +/// to destroy cublasLtHandle_t on CublasLtCreate object destruction. 
+/// Additionally, it provides implicit cast from CublasLtCreate's object to cublasLtHandle_t's object +class CublasLtCreate { +private: + cublasLtHandle_t handle; + cublasStatus_t status; + +public: + CublasLtCreate() { + status = cublasLtCreate(&handle); + } + + ~CublasLtCreate() { + cublasLtDestroy(handle); + } + + /// Implicit cast CublasLtCreate object to cublasLtHandle_t + operator cublasLtHandle_t() const { return handle; } + + /// returns cublasLtStatus_t for handle creation + cublasStatus_t get_cublaslt_create_status() { return status; } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Selects one or more cuBLAS algorithms. +static void select_cublas_algorithms( + std::vector &algorithms, + Options const &options, + library::GemmDescription const &op_desc) { + + library::OpcodeClassID const & opcode_class = + op_desc.tile_description.math_instruction.opcode_class; + + switch (options.library.algorithm_mode) { + case AlgorithmMode::kMatching: + { + algorithms.push_back(get_cublas_gemm_algo( + op_desc.tile_description.threadblock_shape.m(), + op_desc.tile_description.threadblock_shape.n(), + op_desc.tile_description.threadblock_shape.k(), + opcode_class)); + break; + } + + case AlgorithmMode::kBest: + { + // Choose first enumerated mode. If none are enumerated, choose based on opcode class + // and evaluate all of them. 
+ + if (options.library.algorithms.empty()) { + // Enumerate all algorithms + if (opcode_class == library::OpcodeClassID::kSimt) { + + for (int algo = CUBLAS_GEMM_DEFAULT; + algo <= CUBLAS_GEMM_ALGO23; + ++algo) { + + algorithms.push_back(cublasGemmAlgo_t(algo)); + } + } + else { + + for (int algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + algo <= CUBLAS_GEMM_ALGO15_TENSOR_OP; + ++algo) { + + algorithms.push_back(cublasGemmAlgo_t(algo)); + } + } + } + else { + // Use the listed algorithms + algorithms.reserve(options.library.algorithms.size()); + + for (int algo : options.library.algorithms) { + algorithms.push_back(reinterpret_cast(algo)); + } + } + + break; + } + + case AlgorithmMode::kDefault: + { + + // Use the library's default algorithm + algorithms.push_back((opcode_class == library::OpcodeClassID::kSimt ? + CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + break; + } + default: + { + break; + } + } +} + +/// Dispatcher to cublasGemmEx() +struct cublasGemmExDispatcher { + + // + // Data members + // + library::GemmUniversalConfiguration configuration; + library::GemmUniversalArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasOperation_t trans_A; + cublasOperation_t trans_B; + cudaDataType_t data_type_A; + cudaDataType_t data_type_B; + cudaDataType_t data_type_C; + cudaDataType_t compute_data_type; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + cublasGemmAlgo_t algo; + Status status; + + // + // Methods + // + + cublasGemmExDispatcher( + library::GemmDescription const &op_desc, + library::GemmUniversalConfiguration configuration_, + library::GemmUniversalArguments arguments_, + cublasGemmAlgo_t algorithm = CUBLAS_GEMM_DFALT + ); + + /// Executes GEMM using these arguments + cublasStatus_t operator()(cublasHandle_t handle); +}; + +/// Dispatcher to cublaslt kernels +// +struct cublasLtGemmExDispatcher { + + // + // Data members + // + library::GemmDescription const 
&op_desc; + library::GemmUniversalConfiguration configuration; + library::GemmUniversalArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasOperation_t trans_A; + cublasOperation_t trans_B; + cudaDataType_t data_type_A; + cudaDataType_t data_type_B; + cudaDataType_t data_type_C; + cudaDataType_t compute_data_type = CUDA_R_32F; + + //cublasLt-specific data structures + cublasLtMatmulDesc_t operationDesc = NULL; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL, Ddesc = NULL; + cublasLtMatmulPreference_t preference = NULL; + + //is set by call to get_cublaslt_algo() + cublasLtMatmulHeuristicResult_t heuristicResult_; + void *workspace = nullptr; + + Status status; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + // + // Methods + // + + cublasLtGemmExDispatcher( + library::GemmDescription const &op_desc, + library::GemmUniversalConfiguration configuration_, + library::GemmUniversalArguments arguments_ + ); + + /// Initialize the cublasLt variables + void initialize_cublaslt(); + + + /// Runs auto-tuning for the cublas heuristics + bool get_cublaslt_algo(cublasLtHandle_t handle, + AlgorithmMode algorithm_mode + ); + + /// Executes GEMM using these arguments + cublasStatus_t operator()(cublasLtHandle_t handle, cudaStream_t stream = nullptr); + + ~cublasLtGemmExDispatcher(){ + + // descriptors are no longer needed as all GPU work was already enqueued + if (preference) cublasLtMatmulPreferenceDestroy(preference); + if (Ddesc) cublasLtMatrixLayoutDestroy(Ddesc); + if (Cdesc) cublasLtMatrixLayoutDestroy(Cdesc); + if (Bdesc) cublasLtMatrixLayoutDestroy(Bdesc); + if (Adesc) cublasLtMatrixLayoutDestroy(Adesc); + if (operationDesc) cublasLtMatmulDescDestroy(operationDesc); + + if (workspace) { + cudaFree(workspace); + } + + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dispatcher to cublas rank k update 
kernels +struct cublasRankKDispatcher { + + // + // Data members + // + library::RankKConfiguration configuration; + library::RankKArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasOperation_t trans_A; + cublasFillMode_t uplo; + cudaDataType_t data_type_A; + cudaDataType_t data_type_C; + cudaDataType_t compute_data_type; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + int num_ranks; //(rank-k or rank-2k) + BlasMode blas_mode; //(symmetric or hermitian) + Status status; + + // + // Methods + // + + cublasRankKDispatcher( + library::RankKDescription const &op_desc, + library::RankKConfiguration configuration_, + library::RankKArguments arguments_ + ); + + /// Executes RankK using these arguments + cublasStatus_t operator()(cublasHandle_t handle); +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dispatcher to cublasTrmm() +struct cublasTrmmDispatcher { + + // + // Data members + // + library::TrmmConfiguration configuration; + library::TrmmArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasOperation_t trans_A; + cublasSideMode_t side; + cublasFillMode_t uplo; + cublasDiagType_t diag; + cudaDataType_t data_type_A; + cudaDataType_t data_type_B; + cudaDataType_t data_type_D; + cudaDataType_t compute_data_type; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + Status status; + + // + // Methods + // + + cublasTrmmDispatcher( + library::TrmmDescription const &op_desc, + library::TrmmConfiguration configuration_, + library::TrmmArguments arguments_ + ); + + /// Executes TRMM using these arguments + cublasStatus_t operator()(cublasHandle_t handle); +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dispatcher to cublas symm/hemm update kernels +struct cublasSymmDispatcher { + + 
// + // Data members + // + library::SymmConfiguration configuration; + library::SymmArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasSideMode_t side; + cublasFillMode_t uplo; + cudaDataType_t data_type_A; + cudaDataType_t data_type_B; + cudaDataType_t data_type_C; + cudaDataType_t compute_data_type; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + BlasMode blas_mode; //(symmetric or hermitian) + Status status; + + // + // Methods + // + + cublasSymmDispatcher( + library::SymmDescription const &op_desc, + library::SymmConfiguration configuration_, + library::SymmArguments arguments_ + ); + + /// Executes Symm using these arguments + cublasStatus_t operator()(cublasHandle_t handle); +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +} // namespace profiler +} // namespace cutlass + + +#endif // #if CUTLASS_ENABLE_CUBLAS diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cudnn_helpers.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cudnn_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..7ce9eea5a883fa4c5732f5d8aec120a99064bac0 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cudnn_helpers.h @@ -0,0 +1,590 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Helper functions for mapping CUTLASS concepts to cuDNN. 
+ +*/ + +#pragma once +#if CUTLASS_ENABLE_CUDNN +#include +#include +#include +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/library/library.h" +#include "enumerated_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Converts a cuDNN status to cutlass::Status +Status get_cutlass_status(cudnnStatus_t cudnn_status); + +/// Converts a cuDNN status to cutlass::profiler::Disposition +Disposition get_cutlass_disposition(cudnnStatus_t cudnn_status); + +/// Checks cudnnStatus_t converts to cutlas status and returns if Status::kSuccess o.w. throws exception +Status checkCudnnErr(cudnnStatus_t cudnn_status); + +/// Maps a CUTLASS conv mode to a cuDNN conv mode enumeration +bool get_cudnn_conv_mode(cudnnConvolutionMode_t &cudnn_conv_mode, conv::Mode conv_mode); + +/// Maps a CUTLASS layout type to a cuDNN data type enumeration +bool get_cudnn_layout(cudnnTensorFormat_t &cudnn_layout, library::LayoutTypeID layout); + +/// Maps a CUTLASS numeric type to a cuDNN data type enumeration +bool get_cudnn_datatype(cudnnDataType_t &cudnn_element_type, library::NumericTypeID element_type); + +/// Maps CUTLASS math OpcodeClassID and MathOperationID to cuDNN math_type +bool get_cudnn_mathtype(cudnnMathType_t &cudnn_math_type, library::ConvDescription const &conv_desc); + +/// Returns a status if cudnn can satisfy a particular Conv2d description +Status cudnn_satisfies(library::ConvDescription const &desc, library::Conv2dConfiguration const &configuration); + +/// Returns a status if cudnn can satisfy a particular Conv3d description +Status cudnn_satisfies(library::ConvDescription const &desc, library::Conv3dConfiguration const &configuration); + +/// Cudnn compute type seems to be hardcoded to float (To handle a possible cudnn issue) +float 
cast_cudnn_compute_type_to_float(library::NumericTypeID type, void const * src); + + +/// This is a helper class to create cudnnHandle_t automatically on CudnnCreate object creation and +/// to destroy cudnnHandle_t on CudnnCreate object destruction. +/// Additionally, it provides implicit cast from CudnnCreate's object to cudnnHandle_t's object +class CudnnCreate { +private: + cudnnHandle_t handle; + cudnnStatus_t status; + +public: + CudnnCreate() { + status = cudnnCreate(&handle); + } + + ~CudnnCreate() { + cudnnDestroy(handle); + } + + /// Implicit cast CudnnCreate object to cudnnHandle_t + operator cudnnHandle_t() const { return handle; } + + /// returns cudnnStatus_t for handle creation + cudnnStatus_t get_cudnn_create_status() { return status; } +}; + + +namespace detail { + +/// Dispatcher to cudnn convolution operators +struct cudnnConvDispatcher { + + // + // Data members + // + //library::Conv2dConfiguration configuration; + library::ConvArguments arguments; + library::ConvKind conv_kind; + + // cudnn-specific data structures to fill cudnn API call arguments + // cudnn activation, filter, and output descriptors + cudnnTensorDescriptor_t activation_desc; + cudnnFilterDescriptor_t filter_desc; + cudnnTensorDescriptor_t output_desc; + cudnnConvolutionDescriptor_t conv_desc; + + // cudnn datatypes + cudnnDataType_t data_type_activation; + cudnnDataType_t data_type_filter; + cudnnDataType_t data_type_output; + + // cudnn layouts + cudnnTensorFormat_t layout_activation; + cudnnTensorFormat_t layout_filter; + cudnnTensorFormat_t layout_output; + + // cudnn convolution mode + cudnnConvolutionMode_t conv_mode; + + // cudnn math type (tensorop, tensorop with conversion, simt) + cudnnMathType_t math_type; + + // cudnn compute data type + cudnnDataType_t compute_type; + + // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) + float alpha; + float beta; + + // cudnn workspace + size_t workspace_size_in_bytes = 0; + 
cutlass::device_memory::allocation workspace; + + // select cudnn's implicit gemm precomputed algorithm with tensor operations + static cudnnConvolutionFwdAlgo_t const fprop_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + static cudnnConvolutionBwdDataAlgo_t const dgrad_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + static cudnnConvolutionBwdFilterAlgo_t const wgrad_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + + Status status; + + // + // Methods + // + + // TODO: unify ctor cudnnConvDispatcher for conv2d and conv3d by unifying Conv2dConfiguration + + // ctor for conv2d + cudnnConvDispatcher( + library::ConvDescription const &op_desc, + library::Conv2dConfiguration configuration, + library::ConvArguments arguments_, + cudnnHandle_t handle + ): + //configuration(configuration_), + arguments(arguments_), + conv_kind(op_desc.conv_kind), + status(Status::kSuccess) { + + bool good = true; + + // Get cudnn datatype, layout, and convolution mode from library::ConvDescription + good = (good && get_cudnn_datatype(data_type_activation, op_desc.A.element)); + good = (good && get_cudnn_datatype(data_type_filter, op_desc.B.element)); + good = (good && get_cudnn_datatype(data_type_output, op_desc.C.element)); + good = (good && get_cudnn_layout(layout_activation, op_desc.A.layout)); + good = (good && get_cudnn_layout(layout_filter, op_desc.B.layout)); + good = (good && get_cudnn_layout(layout_output, op_desc.C.layout)); + good = (good && get_cudnn_conv_mode(conv_mode, configuration.problem_size.mode)); + // Get cudnn mathtype (cudnnMathType_t) + good = (good && get_cudnn_mathtype(math_type, op_desc)); + good = (good && get_cudnn_datatype( + compute_type, + op_desc.tile_description.math_instruction.element_accumulator)); + // Check cutlass Conv2d description has equivalent operator in cudnn + if (!good) { + status = Status::kErrorNotSupported; + return; + } + // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) + alpha = 
cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.alpha); + beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.beta); + + // Create convolution descriptor object + status = get_cutlass_status(cudnnCreateConvolutionDescriptor(&conv_desc)); + + // Configure convolution operator + std::vector padding {configuration.problem_size.pad_h, configuration.problem_size.pad_w}; + std::vector stride {configuration.problem_size.stride_h, configuration.problem_size.stride_w}; + std::vector dilation {configuration.problem_size.dilation_h, configuration.problem_size.dilation_w}; + + status = get_cutlass_status( + cudnnSetConvolutionNdDescriptor( + conv_desc, + op_desc.conv_dim, + padding.data(), + stride.data(), + dilation.data(), + conv_mode, + compute_type + )); + + // Set groups + status = get_cutlass_status(cudnnSetConvolutionGroupCount(conv_desc, configuration.problem_size.groups)); + + // Create activation, filter, and output descriptor objects + status = get_cutlass_status(cudnnCreateTensorDescriptor(&activation_desc)); + status = get_cutlass_status(cudnnCreateFilterDescriptor(&filter_desc)); + status = get_cutlass_status(cudnnCreateTensorDescriptor(&output_desc)); + + // Set activation, filter, and output descriptor + status = get_cutlass_status( + cudnnSetTensor4dDescriptor( + activation_desc, + layout_activation, + data_type_activation, + configuration.problem_size.N, + configuration.problem_size.C, + configuration.problem_size.H, + configuration.problem_size.W + )); + + status = get_cutlass_status( + cudnnSetFilter4dDescriptor( + filter_desc, + data_type_filter, + layout_filter, + configuration.problem_size.K, + configuration.problem_size.C / configuration.problem_size.groups, + configuration.problem_size.R, + configuration.problem_size.S + )); + + status = get_cutlass_status( + cudnnSetTensor4dDescriptor( + output_desc, + layout_output, + data_type_output, + configuration.problem_size.N, + configuration.problem_size.K, + 
configuration.problem_size.P, + configuration.problem_size.Q + )); + + // Set math instruction to tensor op + status = get_cutlass_status( + cudnnSetConvolutionMathType(conv_desc, math_type)); + + // Initialize workspace + switch (conv_kind) { + case library::ConvKind::kFprop: + status = get_cutlass_status( + cudnnGetConvolutionForwardWorkspaceSize( + handle, + activation_desc, + filter_desc, + conv_desc, + output_desc, + fprop_algo, + &workspace_size_in_bytes + )); break; + case library::ConvKind::kDgrad: + status = get_cutlass_status( + cudnnGetConvolutionBackwardDataWorkspaceSize( + handle, + filter_desc, + output_desc, + conv_desc, + activation_desc, + dgrad_algo, + &workspace_size_in_bytes + )); break; + case library::ConvKind::kWgrad: + status = get_cutlass_status( + cudnnGetConvolutionBackwardFilterWorkspaceSize( + handle, + activation_desc, + output_desc, + conv_desc, + filter_desc, + wgrad_algo, + &workspace_size_in_bytes + )); break; + + } + + workspace = cutlass::device_memory::allocation(workspace_size_in_bytes); + } + + + // ctor for conv3d + cudnnConvDispatcher( + library::ConvDescription const &op_desc, + library::Conv3dConfiguration configuration, + library::ConvArguments arguments_, + cudnnHandle_t handle + ): + //configuration(configuration_), + arguments(arguments_), + conv_kind(op_desc.conv_kind), + status(Status::kSuccess) { + + bool good = true; + + // Get cudnn datatype, layout, and convolution mode from library::ConvDescription + good = (good && get_cudnn_datatype(data_type_activation, op_desc.A.element)); + good = (good && get_cudnn_datatype(data_type_filter, op_desc.B.element)); + good = (good && get_cudnn_datatype(data_type_output, op_desc.C.element)); + + good = (good && get_cudnn_layout(layout_activation, op_desc.A.layout)); + good = (good && get_cudnn_layout(layout_filter, op_desc.B.layout)); + good = (good && get_cudnn_layout(layout_output, op_desc.C.layout)); + + good = (good && get_cudnn_conv_mode(conv_mode, 
configuration.problem_size.mode)); + + // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) + alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.alpha); + beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.beta); + + good = (good && get_cudnn_datatype( + compute_type, + op_desc.tile_description.math_instruction.element_accumulator)); + + // Check cutlass Conv2d description has equivalent operator in cudnn + if (!good) { + status = Status::kErrorNotSupported; + } + + // Create convolution descriptor object + status = get_cutlass_status(cudnnCreateConvolutionDescriptor(&conv_desc)); + + // Configure convolution operator + std::vector padding {configuration.problem_size.pad_d, configuration.problem_size.pad_h, configuration.problem_size.pad_w}; + std::vector stride {configuration.problem_size.stride_d, configuration.problem_size.stride_h, configuration.problem_size.stride_w}; + std::vector dilation {configuration.problem_size.dilation_d, configuration.problem_size.dilation_h, configuration.problem_size.dilation_w}; + + status = get_cutlass_status( + cudnnSetConvolutionNdDescriptor( + conv_desc, + op_desc.conv_dim, + padding.data(), + stride.data(), + dilation.data(), + conv_mode, + compute_type + )); + + // Set groups + status = get_cutlass_status(cudnnSetConvolutionGroupCount(conv_desc, configuration.problem_size.groups)); + + // Create activation, filter, and output descriptor objects + status = get_cutlass_status(cudnnCreateTensorDescriptor(&activation_desc)); + status = get_cutlass_status(cudnnCreateFilterDescriptor(&filter_desc)); + status = get_cutlass_status(cudnnCreateTensorDescriptor(&output_desc)); + + // Set activation descriptor + std::vector activation_extent { + configuration.problem_size.N, + configuration.problem_size.C, + configuration.problem_size.D, + configuration.problem_size.H, + configuration.problem_size.W + }; + + std::vector activation_stride { + 
configuration.layout_activations.stride()[3], + 1, + configuration.layout_activations.stride()[2], + configuration.layout_activations.stride()[1], + configuration.layout_activations.stride()[0] + }; + + status = get_cutlass_status( + cudnnSetTensorNdDescriptor( + activation_desc, + data_type_activation, + op_desc.conv_dim + 2, + activation_extent.data(), + activation_stride.data() + )); + + // Set filter descriptor + std::vector filter_extent { + configuration.problem_size.K, + configuration.problem_size.C, + configuration.problem_size.T, + configuration.problem_size.R, + configuration.problem_size.S + }; + + std::vector filter_stride { + configuration.layout_filters.stride()[3], + 1, + configuration.layout_filters.stride()[2], + configuration.layout_filters.stride()[1], + configuration.layout_filters.stride()[0] + }; + + status = get_cutlass_status( + cudnnSetFilterNdDescriptor( + filter_desc, + data_type_filter, + layout_filter, + op_desc.conv_dim + 2, + filter_extent.data() + )); + + + // Set output descriptor + std::vector output_extent { + configuration.problem_size.N, + configuration.problem_size.K, + configuration.problem_size.Z, + configuration.problem_size.P, + configuration.problem_size.Q + }; + + std::vector output_stride { + configuration.layout_output.stride()[3], + 1, + configuration.layout_output.stride()[2], + configuration.layout_output.stride()[1], + configuration.layout_output.stride()[0] + }; + + status = get_cutlass_status( + cudnnSetTensorNdDescriptor( + output_desc, + data_type_output, + op_desc.conv_dim + 2, + output_extent.data(), + output_stride.data() + )); + + // Set math instruction to tensor op + status = get_cutlass_status( + cudnnSetConvolutionMathType(conv_desc, math_type)); + + // Initialize workspace + switch (conv_kind) { + case library::ConvKind::kFprop: + status = get_cutlass_status( + cudnnGetConvolutionForwardWorkspaceSize( + handle, + activation_desc, + filter_desc, + conv_desc, + output_desc, + fprop_algo, + 
&workspace_size_in_bytes + )); break; + case library::ConvKind::kDgrad: + status = get_cutlass_status( + cudnnGetConvolutionBackwardDataWorkspaceSize( + handle, + filter_desc, + output_desc, + conv_desc, + activation_desc, + dgrad_algo, + &workspace_size_in_bytes + )); break; + case library::ConvKind::kWgrad: + status = get_cutlass_status( + cudnnGetConvolutionBackwardFilterWorkspaceSize( + handle, + activation_desc, + output_desc, + conv_desc, + filter_desc, + wgrad_algo, + &workspace_size_in_bytes + )); break; + + } + + workspace = cutlass::device_memory::allocation(workspace_size_in_bytes); + } + + /// Executes Conv2d operator from cudnn library + cudnnStatus_t operator()(cudnnHandle_t handle) { + + switch (conv_kind) { + case library::ConvKind::kFprop: + return cudnnConvolutionForward( + handle, + &alpha, + activation_desc, + activation(), + filter_desc, + filter(), + conv_desc, + fprop_algo, + workspace.get(), + workspace_size_in_bytes, + &beta, + output_desc, + arguments.D + ); + case library::ConvKind::kDgrad: + return cudnnConvolutionBackwardData( + handle, + &alpha, + filter_desc, + filter(), + output_desc, + output(), + conv_desc, + dgrad_algo, + workspace.get(), + workspace_size_in_bytes, + &beta, + activation_desc, + arguments.D + ); + case library::ConvKind::kWgrad: + return cudnnConvolutionBackwardFilter( + handle, + &alpha, + activation_desc, + activation(), + output_desc, + output(), + conv_desc, + wgrad_algo, + workspace.get(), + workspace_size_in_bytes, + &beta, + filter_desc, + arguments.D + ); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Activation Tensor + void const * activation() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return arguments.A; + case library::ConvKind::kDgrad : return arguments.C; + case library::ConvKind::kWgrad : return arguments.B; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns 
Filter Tensor + void const *filter() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return arguments.B; + case library::ConvKind::kDgrad : return arguments.B; + case library::ConvKind::kWgrad : return arguments.C; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Output Tensor + void const *output() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return arguments.C; + case library::ConvKind::kDgrad : return arguments.A; + case library::ConvKind::kWgrad : return arguments.A; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } +}; + +} // namespace detail +///////////////////////////////////////////////////////////////////////////////////////////////// +#endif //#if CUTLASS_ENABLE_CUDNN +} // namespace profiler +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cutlass_profiler.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cutlass_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..be82245325cebb147e2c801965a52ece91395cb2 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cutlass_profiler.h @@ -0,0 +1,93 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ *
+ **************************************************************************************************/
+/* \file
+   \brief Execution environment
+*/
+
+#pragma once
+// CUTLASS Library includes
+#include "cutlass/library/library.h"
+#include "cutlass/library/manifest.h"
+#include "cutlass/library/singleton.h"
+
+#include "options.h"
+#include "operation_profiler.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace profiler {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// CUTLASS Profiler application
+///
+/// Constructed from an Options object and invoked through operator(), which returns an int.
+/// Only declarations are visible in this header; the definitions live elsewhere.
+class CutlassProfiler {
+private:
+
+  //
+  // Data members
+  //
+
+  /// Performance testbench options
+  Options options_;
+
+  /// Entry points for each operation
+  OperationProfilerVector operation_profilers_;
+
+private:
+
+  /// Prints usage to the given stream
+  void print_usage_(std::ostream &);
+
+  /// Prints options to the given stream
+  /// (NOTE(review): this comment previously read "Prints usage", duplicating print_usage_;
+  /// confirm the exact output against the definition)
+  void print_options_(std::ostream &);
+
+  /// Enumerates all operations
+  void enumerate_();
+
+  /// Profiles all operations; returns an int (presumably an exit/status code, TODO confirm)
+  int profile_();
+
+public:
+
+  /// Constructs the profiler from the given options
+  CutlassProfiler(Options const &options);
+  ~CutlassProfiler();
+
+  /// Invokes profiling operations
+  int operator()();
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace profiler
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/debug.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/debug.h
new file mode 100644
index 0000000000000000000000000000000000000000..98f1fdc3044501e456c927471b30d74b09eafd39
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/debug.h
@@ -0,0 +1,56 @@
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ *
+ **************************************************************************************************/
+/* \file
+   \brief Optional printf-style debug tracing macro (debugprof) for the profiler
+*/
+
+#pragma once
+
+#include
+
+//#define report(x) { std::cout << "\033[31m" << __FILE__ << ":" << __LINE__ << "  " << x << "\033[0m" << std::endl; }
+//#define report(x) {}
+
+// Enable/Disable Profiler debug prints
+//#define DEBUG_PROFILER
+
+//RED 31m        // profiler prints debug messages in red
+//YELLOW 33m     // ir prints debug messages in yellow
+// NOTE(review): the debugprof macro below emits the 33m (yellow) escape, not 31m as the legend
+// above suggests for the profiler; confirm which colour is intended.
+
+// debugprof(fmt, ...): printf-style debug trace.
+//   - Without DEBUG_PROFILER defined, it expands to nothing, so its arguments are never evaluated.
+//   - With DEBUG_PROFILER defined, it prints "[DEBUG PROF] <file>:<line> | " followed by the
+//     formatted message, wrapped in the 33m colour escape and terminated by a reset and newline.
+//     The do { ... } while (0) wrapper makes the expansion behave as a single statement.
+#ifndef DEBUG_PROFILER
+#define debugprof(...)
+#else
+#define debugprof(...) do { \
+      printf("\033[33m[DEBUG PROF] %s:%d | ", __FILE__, __LINE__); \
+      printf(__VA_ARGS__); \
+      printf("\033[0m\n"); \
+    } while (0)
+#endif
diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_allocation.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_allocation.h
new file mode 100644
index 0000000000000000000000000000000000000000..488b635c2ec233e3027303bbf15a34f375a438fd
--- /dev/null
+++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_allocation.h
@@ -0,0 +1,246 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Device memory allocation (DeviceAllocation) used by the profiler +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/library/library.h" +#include "cutlass/util/distribution.h" + +#include "enumerated_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Device memory allocation +class DeviceAllocation { +private: + + /// Data type of contained elements + library::NumericTypeID type_; + + /// Stride (in units of elements) between items of a batch + size_t batch_stride_; + + /// Capacity in elements of device allocation + size_t capacity_; + + /// Pointer to device memory + void *pointer_; + + /// Layout type ID + library::LayoutTypeID layout_; + + /// Stride vector + std::vector stride_; + + /// Extent vector + std::vector extent_; + + /// Support allocating a 'batch' of non-overlapping tensors in contiguous memory + int batch_count_; + + /// Buffer holding TensorRef instance to recently allocated memory + std::vector tensor_ref_buffer_; + + /// The device ID where the allocation is made + int device_; + +public: + // + // Static member functions + // + + /// Determines the number of bytes needed to represent this numeric type + static size_t bytes(library::NumericTypeID type, size_t capacity); + + /// Returns the stride of a packed layout + static std::vector get_packed_layout( + library::LayoutTypeID layout_id, + std::vector const &extent); + + /// returns the capacity needed + static size_t construct_layout( + void *bytes, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector &stride); + + /// Returns true if two blocks have exactly the same value + static bool block_compare_equal( + library::NumericTypeID numeric_type, + void const *ptr_A, + void const 
*ptr_B, + size_t capacity); + + /// Returns true if two blocks have approximately the same value + static bool block_compare_relatively_equal( + library::NumericTypeID numeric_type, + void const *ptr_A, + void const *ptr_B, + size_t capacity, + double epsilon, + double nonzero_floor); + +public: + // + // Methods + // + + DeviceAllocation(); + + DeviceAllocation( + library::NumericTypeID type, + size_t capacity, + int device = -1); + + DeviceAllocation( + library::NumericTypeID type, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector const &stride = std::vector(), + int batch_count = 1, + int device = -1); + + ~DeviceAllocation(); + + DeviceAllocation &reset(); + + /// Allocates device memory of a given type and capacity + DeviceAllocation &reset(library::NumericTypeID type, size_t capacity); + + /// Allocates memory for a given layout and tensor + DeviceAllocation &reset( + library::NumericTypeID type, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector const &stride = std::vector(), + int batch_count = 1); + + /// Returns a buffer owning the tensor reference + std::vector &tensor_ref() { + return tensor_ref_buffer_; + } + + bool good() const; + + /// Data type of contained elements + library::NumericTypeID type() const; + + /// Pointer to start of device memory allocation + void *data() const; + + /// Pointer to the first element of a batch + void *batch_data(int batch_idx) const; + + /// Gets the layout type + library::LayoutTypeID layout() const; + + /// Gets the stride vector + std::vector const & stride() const; + + /// Gets the extent vector + std::vector const & extent() const; + + /// Gets the number of adjacent tensors in memory + int batch_count() const; + + /// Gets the stride (in units of elements) between items + int64_t batch_stride() const; + + /// Gets the stride (in units of bytes) between items + int64_t batch_stride_bytes() const; + + /// Capacity of allocation in number of elements + size_t 
capacity() const; + + /// Capacity of allocation in bytes + size_t bytes() const; + + /// Initializes a device allocation to a random distribution using cuRAND + void initialize_random_device(int seed, Distribution dist); + + /// Initializes a host allocation to a random distribution + void initialize_random_host(int seed, Distribution dist); + + /// Initializes a device allocation to a sequential distribution + void initialize_sequential_device(Distribution dist); + + /// Initializes a host allocation to a sequential distribution + void initialize_sequential_host(Distribution dist); + + /// Initializes a device allocation to a random distribution using cuRAND + void initialize_random_sparsemeta_device(int seed, int MetaSizeInBits); + + /// Initializes a host allocation to a random distribution + void initialize_random_sparsemeta_host(int seed, int MetaSizeInBits); + + /// Uniformly fills a tensor with a value when provided o.w. zero + void fill_device(double value); + + /// Uniformly fills a host allocation with a value when provided o.w. 
zero + void fill_host(double value); + + /// Copies from an equivalent-sized tensor in device memory + void copy_from_device(void const *ptr); + + /// Copies from an equivalent-sized tensor in host memory + void copy_from_host(void const *ptr); + + /// Copies to an equivalent-sized tensor in host memory + void copy_to_host(void *ptr); + + /// Writes a tensor to csv + void write_tensor_csv(std::ostream &out); + +private: + /// A wrapper that sets the device, performs malloc, and sets back + cudaError_t malloc(void** ptr, size_t size); +}; + +using DeviceAllocationList = std::list; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_context.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_context.h new file mode 100644 index 0000000000000000000000000000000000000000..0443b340397426bfafc812c1a4b9179fc6af0de4 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_context.h @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Collection of named allocations on the device (DeviceContext) +*/ + +#pragma once + +#include +#include + + +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" + +#include "options.h" +#include "device_allocation.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Collection of allocations on the device +class DeviceContext { +public: + + // + // Type definitions + // + using AllocationMap = std::map; + +private: + // + // Data members + // + + /// Memory allocations that exist (owning) + DeviceAllocationList device_memory_; + + /// Non-owning set of named allocations + AllocationMap allocations_; + +public: + + /// Allocates memory of a given type, capacity (elements), and name + DeviceAllocation *allocate_block( + Options const &options, + std::string const &name, + library::NumericTypeID type, + size_t capacity, + size_t device_index); + + /// Allocates memory for a named tensor of a given type, layout, extent, stride, and batch count + DeviceAllocation *allocate_tensor( + Options const &options, + std::string const &name, + library::NumericTypeID type, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector const &stride, + int batch_count, + size_t device_index); + + /// Allocates and initializes a named tensor (initialization logic is in the definition, not in this header) + DeviceAllocation *allocate_and_initialize_tensor( + Options const &options, + std::string const &name, + library::NumericTypeID type, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector const &stride, + int batch_count, + int seed_shift, + size_t device_index); + + /// Allocates memory for sparse meta data + DeviceAllocation *allocate_and_initialize_sparsemeta_tensor( + Options const &options, + std::string const &name, + library::NumericTypeID type, + library::LayoutTypeID layout_id, + library::NumericTypeID type_a, 
+ std::vector const &extent, + std::vector const &stride, + int batch_count, + int seed_shift, + size_t device_index); + + /// Clears named allocations (but does not necessarily free memory) + void clear(); + + /// Frees all device memory allocations + void free(); + + /// Gets the allocation by name + DeviceAllocation &at(std::string const &name); + + size_t size() const; + + AllocationMap::iterator begin(); + AllocationMap::iterator end(); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/enumerated_types.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/enumerated_types.h new file mode 100644 index 0000000000000000000000000000000000000000..897311c228ce76c4e8814ce996929561d44d2465 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/enumerated_types.h @@ -0,0 +1,169 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Enumerated types used by the profiler, with to_string / from_string conversions. +*/ + +#pragma once + +#include +#include +#include +#include +#include "cutlass/library/library.h" + +#define TRACE(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; } + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +T from_string(std::string const &); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Enumerated type describing how the performance testbench evaluates kernels. 
+enum class ExecutionMode { + kProfile, ///< regular verification and profiling + kDryRun, ///< no kernels are launched or workspaces allocated; used to assess what operators might be launched + kEnumerate, ///< no kernels launched or workspaces allocated; lists all operation kind and operations + kTrace, ///< executes a single device-side computation with no other kernel launches + kInvalid +}; + +/// Converts a ExecutionMode enumerant to a string +char const *to_string(ExecutionMode mode, bool pretty = false); + +/// Parses a ExecutionMode enumerant from a string +template <> +ExecutionMode from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Library algorithm mode +enum class AlgorithmMode { + kMatching, ///< compare against best matching algorithm + kBest, ///< evaluate all library algorithms and report best + kDefault, ///< use the library's default algorithm option + kInvalid +}; + +/// Converts an AlgorithmMode enumerant to a string +char const *to_string(AlgorithmMode mode, bool pretty = false); + +/// Parses an AlgorithmMode enumerant from a string +template <> +AlgorithmMode from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Outcome of a performance test +enum class Disposition { + kPassed, + kFailed, // kernel itself reported an error + kNotRun, + kIncorrect, // kernel finished without a detected error, but result does not equal expected result + kNotVerified, + kInvalidProblem, + kNotSupported, + kInvalid +}; + +/// Converts a Disposition enumerant to a string +char const *to_string(Disposition disposition, bool pretty = false); + +/// Parses a Disposition enumerant from a string +template <> +Disposition from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Indicates when to save the workspace 
+enum class SaveWorkspace { + kNever, + kIncorrect, + kAlways, + kInvalid +}; + +/// Converts a SaveWorkspace enumerant to a string +char const *to_string(SaveWorkspace save_option, bool pretty = false); + +/// Parses a SaveWorkspace enumerant from a string +template <> +SaveWorkspace from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Indicates the type of kernel argument +// ArgumentType can be both ScalarType or NumericType. Thus, enums kScalar and kNumeric +// 1) kScalar: e.g. of a Scalar ArgumentType is u32 is a Scalar type. +// Its c++ equivalent as "type name = initializer" is "u32 m = 32" +// 2) kNumeric: e.g. of a Numeric ArgumentType is NumericTypeID is a Numeric type. +// Its c++ equivalent as "type name = initializer" is "NumericTypeID numeric_type = u32" (NOTE(review): no kNumeric enumerant is declared below; this comment may be stale) +enum class ArgumentTypeID { + kScalar, + kInteger, + kTensor, + kBatchedTensor, + kStructure, + kEnumerated, + kInvalid +}; + +/// Converts a ArgumentTypeID enumerant to a string +char const *to_string(ArgumentTypeID type, bool pretty = false); + +/// Parses a ArgumentTypeID enumerant from a string +template <> +ArgumentTypeID from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Profiler typedefs +using ProviderVector = std::vector; +using DispositionMap = std::map; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Print vector for the report +template +std::ostream& operator<< (std::ostream& out, const std::vector& v) { + for (size_t i = 0; i < v.size(); ++i) { + out << to_string(v[i], true) << (i + 1u != v.size() ? 
"," : ""); + } + return out; +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..faf317152473cac6dc62ecf8970cd1acfb2c1622 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h @@ -0,0 +1,333 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Gemm Profiler +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Profiler for GEMM operations (derives from OperationProfiler) +class GemmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct GemmProblem { + + cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm}; + + /// For profiling purposes + std::vector problem_sizes; + std::vector> leading_dims; + std::vector> preferred_clusters; + std::vector> fallback_clusters; + std::vector raster_orders; + std::vector swizzle_sizes; + + int64_t m{16}; + int64_t n{16}; + int64_t k{16}; + + + int cluster_m{1}; + int cluster_n{1}; + int cluster_k{1}; + int cluster_m_fallback{1}; 
+ int cluster_n_fallback{1}; + int cluster_k_fallback{1}; + + + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + std::vector alpha; + std::vector beta; + + cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone}; + int split_k_slices{1}; + int batch_count{1}; + + cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; + int swizzle_size{1}; + cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; + cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; + + + // gemm with parallel interleaved reduction + // gemm epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + bool use_pdl{false}; + + bool enable_sm90_mixed_dtype_shuffle_test{false}; + + // + // Methods + // + + /// Parses the problem + Status parse( + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + int64_t bytes_with_problem_shape( + library::GemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + int64_t flops_with_problem_shape( + library::GemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + /// Total number of bytes loaded + int64_t bytes(library::GemmDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::GemmDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct GemmWorkspace { + + DeviceAllocation *A{nullptr}; + DeviceAllocation *B{nullptr}; + DeviceAllocation *C{nullptr}; + DeviceAllocation *Computed{nullptr}; + DeviceAllocation *Reference{nullptr}; + + /// Number of copies of the problem workspace which are visited 
sequentially during + /// profiling to avoid camping in the last level cache. + int problem_count{1}; + + library::GemmUniversalConfiguration configuration; + library::GemmUniversalArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + + /// For mixed input dtype kernels + DeviceAllocation *Scale{nullptr}; // Scale tensor + DeviceAllocation *Zero{nullptr}; // Zero tensor + DeviceAllocation *dequantized_AB{nullptr}; // Dequantized A or B tensor for verification + DeviceAllocation *encoded_AB{nullptr}; // Encoded A or B in int4 x fp8 or shuffle + DeviceAllocation *packed_Scale{nullptr}; // Packed scale for int4 * fp8 + + cudaStream_t stream; + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + GemmProblem problem_; + + /// Device memory allocations + std::vector gemm_workspace_; + + /// CUTLASS parallel reduction operation to follow this gemm operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + GemmOperationProfiler(Options const &options); + + /// Destructor + virtual ~GemmOperationProfiler(); + + GemmProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, 
+ ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + /// Update workspace configuration according to flexible user setups + void update_workspace_( + GemmWorkspace &gemm_workspace, + gemm::GemmCoord const &problem_shape, + std::array const &leading_dim, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + cutlass::library::RasterOrder const &raster_order, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Update performance result configuration according to flexible user setups + void update_result_( + PerformanceResult &result, + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space, + gemm::GemmCoord const &problem_shape, + cutlass::library::RasterOrder const &raster_order, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against 
references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem, + GemmWorkspace &gemm_workspace); + + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); + + /// Method to profile a CUTLASS Operation + Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + library::Operation const *operation, + ProblemSpace::Problem const &problem); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gpu_timer.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gpu_timer.h new file mode 100644 index 0000000000000000000000000000000000000000..154045295d6443d930ba53387366f4b8abe408a4 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gpu_timer.h @@ -0,0 +1,77 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 
2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Defines a math function +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct GpuTimer { + + cudaEvent_t events[2]; + + // + // Methods + // + + GpuTimer(); + + GpuTimer(GpuTimer const&) = delete; + + GpuTimer(GpuTimer &&gpu_timer) noexcept; + + ~GpuTimer(); + + /// Records a start event in the stream, the flag is for cudaEventRecordWithFlags + void start(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); + + /// Records a stop event in the stream, the flag is for cudaEventRecordWithFlags + void stop(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); + + /// Records a stop event in the stream and synchronizes on the stream, the flag is for cudaEventRecordWithFlags + void stop_and_wait(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); + + /// Returns the duration in milliseconds + double duration(int iterations = 1) const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..62d47990584cbb984935a00a267cff15dbb4f4e5 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h @@ -0,0 +1,344 @@ 
+/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/* \file + \brief GroupedGemm Profiler +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" + +// Profiler includes +#include "device_context.h" +#include "operation_profiler.h" +#include "options.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class GroupedGemmOperationProfiler : public OperationProfiler { +public: + /// Problem structure obtained from problem space + struct GroupedGemmProblem { + + cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGrouped}; + + std::vector problem_sizes; + std::vector> problem_sizes_3x; + + /// For exploration purposes + std::vector> preferred_clusters; + std::vector> fallback_clusters; + std::vector raster_orders; + std::vector swizzle_sizes; + + int cluster_m{1}; + int cluster_n{1}; + int cluster_k{1}; + int cluster_m_fallback{1}; + int cluster_n_fallback{1}; + int cluster_k_fallback{1}; + + std::vector lda{0}; + std::vector ldb{0}; + std::vector ldc{0}; + + std::vector alpha; + std::vector beta; + + cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; + int swizzle_size{1}; + + cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; + cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; + + bool use_pdl{false}; + + /// Parses the problem + Status parse( + library::GroupedGemmDescription const& operation_desc, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + int64_t m(int group_idx) const { return problem_sizes[group_idx].m(); 
}; + int64_t n(int group_idx) const { return problem_sizes[group_idx].n(); }; + int64_t k(int group_idx) const { return problem_sizes[group_idx].k(); }; + + /// Total number of bytes loaded + int64_t bytes(library::GroupedGemmDescription const& operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::GroupedGemmDescription const& operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult& result, + library::GroupedGemmDescription const& operation_desc, + ProblemSpace const& problem_space); + }; + + struct BlockScalingWorkspace { + // host vector (per L2 workspace) of device vectors (per group) of device pointers + std::vector SFA_ptr_array_device; + std::vector SFB_ptr_array_device; + std::vector SFC_ptr_array_device; + std::vector SFD_ptr_array_device; + + // host vector (per group) of device tensors + // (where each batch of device allocation is for a L2 workspace) + std::vector SFA_ptr_array_host; + std::vector SFB_ptr_array_host; + std::vector SFC_ptr_array_host; + std::vector SFD_ptr_array_host; + std::vector SFD_reference_ptr_array_host; + + // matrix wide constant, not per-batch or per-group + DeviceAllocation* norm_constant; + }; + + // workspace contains the allocated blocks, arguments just contain the raw + // pointers + struct GroupedGemmWorkspace { + + // host vector (per L2 workspace) of device vectors (per group) of device pointers + std::vector A_ptr_array_device; + std::vector B_ptr_array_device; + std::vector C_ptr_array_device; + std::vector D_ptr_array_device; + std::vector reference_ptr_array_host; + + // host vector (per group) of device tensors + // (where each batch of device allocation is for a L2 workspace) + std::vector A_ptr_array_host; + std::vector B_ptr_array_host; + std::vector C_ptr_array_host; + std::vector D_ptr_array_host; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the 
last level cache. + /// *NOT* the number of groups in the grouped GEMM (we use `num_groups` in the profiler) + int problem_count{1}; + + DeviceAllocation* problem_sizes_array_device{nullptr}; + DeviceAllocation* problem_sizes_3x_array_device{nullptr}; + DeviceAllocation* lda_array_device{nullptr}; + DeviceAllocation* ldb_array_device{nullptr}; + DeviceAllocation* ldc_array_device{nullptr}; + DeviceAllocation* ldd_array_device{nullptr}; + + std::optional block_scales; + + library::GemmGroupedConfiguration configuration; + library::GroupedGemmBlockScaledArguments arguments; + + std::vector host_workspace; + DeviceAllocation device_workspace; + + cudaStream_t stream; + }; + +private: + void init_arguments(Options const& options) { + auto& arguments = gemm_workspace_.arguments; + // these get updated in each profiler run to ensure L2 cycling + arguments.ptr_A = gemm_workspace_.A_ptr_array_device[0]->data(); + arguments.ptr_B = gemm_workspace_.B_ptr_array_device[0]->data(); + arguments.ptr_C = gemm_workspace_.C_ptr_array_device[0]->data(); + arguments.ptr_D = gemm_workspace_.D_ptr_array_device[0]->data(); + + arguments.alpha = problem_.alpha.data(); + arguments.beta = problem_.beta.data(); + arguments.pointer_mode = library::ScalarPointerMode::kHost; + arguments.lda = static_cast(gemm_workspace_.lda_array_device->data()); + arguments.ldb = static_cast(gemm_workspace_.ldb_array_device->data()); + arguments.ldc = static_cast(gemm_workspace_.ldc_array_device->data()); + arguments.ldd = static_cast(gemm_workspace_.ldc_array_device->data()); + arguments.problem_sizes = + static_cast(gemm_workspace_.problem_sizes_array_device->data()); + arguments.problem_sizes_3x = static_cast*>( + gemm_workspace_.problem_sizes_3x_array_device->data()); + gemm_workspace_.arguments.problem_sizes_3x_host = problem_.problem_sizes_3x.data(); + gemm_workspace_.arguments.problem_count = problem_.problem_sizes.size(); + gemm_workspace_.arguments.cluster_shape = {int(problem_.cluster_m), 
int(problem_.cluster_n), int(problem_.cluster_k)}; + gemm_workspace_.arguments.cluster_shape_fallback = {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)}; + + /* Query device SM count to pass onto the kernel as an argument, where needed */ + arguments.sm_count = options.device.get_sm_count(0); + if (is_block_scaled) { + auto& block_scaled_ws = gemm_workspace_.block_scales.value(); + arguments.SFA = block_scaled_ws.SFA_ptr_array_device[0]->data(); + arguments.SFB = block_scaled_ws.SFB_ptr_array_device[0]->data(); + arguments.SFD = block_scaled_ws.SFD_ptr_array_device[0]->data(); + arguments.norm_constant = block_scaled_ws.norm_constant->data(); + } + else if (is_blockwise) { + auto& block_scaled_ws = gemm_workspace_.block_scales.value(); + arguments.SFA = block_scaled_ws.SFA_ptr_array_device[0]->data(); + arguments.SFB = block_scaled_ws.SFB_ptr_array_device[0]->data(); + } + } + +protected: + /// GEMM problem obtained from problem space + GroupedGemmProblem problem_; + + /// Device memory allocations + GroupedGemmWorkspace gemm_workspace_; + + bool is_block_scaled{false}; + bool is_blockwise{false}; + +public: + GroupedGemmOperationProfiler(Options const& options); + + virtual ~GroupedGemmOperationProfiler(); + + GroupedGemmProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream& out) const; + + /// Prints examples + virtual void print_examples(std::ostream& out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation 
const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + /// Measures performance results + virtual bool profile( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + +protected: + /// Initializes the performance result + void initialize_result_( + PerformanceResult& result, + Options const& options, + library::GroupedGemmDescription const& operation_desc, + ProblemSpace const& problem_space); + + /// Update workspace configuration according to flexible user setups + void update_workspace_( + GroupedGemmWorkspace &gemm_workspace, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + cutlass::library::RasterOrder const &raster_order, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Update performance result configuration for exploration parameters + void update_workspace_and_result_( + GroupedGemmWorkspace &gemm_workspace, + PerformanceResult &result, + ProblemSpace const &problem_space, + cutlass::library::RasterOrder const &raster_order, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); + + /// 
Method to profile a CUTLASS Operation + Status profile_cutlass_( + PerformanceResult& result, + Options const& options, + library::Operation const* operation, + void* arguments, + void* host_workspace, + void* device_workspace) override; + + /// Method to profile a CUTLASS Operation for the best configuration for a fixed shape + bool profile_cutlass_for_fixed_shape_( + Options const& options, + library::Operation const* operation, + ProblemSpace const& problem_space); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/operation_profiler.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..446ef2c16739b28aaf038ca62bad6e3cdf667813 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/operation_profiler.h @@ -0,0 +1,287 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Defines a math function +*/ + +#pragma once + +#include +#include +#include +#include + +// CUTLASS includes +#include "cutlass/trace.h" + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "performance_result.h" +#include "performance_report.h" +#include "problem_space.h" +#include "debug.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class OperationProfiler { +public: + + +protected: + // + // Data members + // + + /// Top-level operation kind + library::OperationKind kind_; + + /// Human readable description + std::string description_; + + /// Arguments parsed from command line + ArgumentDescriptionVector arguments_; + + /// List of providers used to verify and compare each result + ProviderVector verification_providers_; + + /// Model performance result initialized by the operation profiler with workload statistics + /// and reasonable default state. 
+ PerformanceResult model_result_; + + /// Performance result vector constructed by profiling the operation + PerformanceResultVector results_; + +public: + + // + // Methods + // + + /// Ctor + OperationProfiler(); + + OperationProfiler( + Options const &options, + library::OperationKind kind, + ArgumentDescriptionVector const &arguments = ArgumentDescriptionVector(), + ProviderVector const & verification_providers = ProviderVector()); + + /// Destructor + virtual ~OperationProfiler(); + + /// Obtains the operation kind + library::OperationKind kind() const { return kind_; } + + /// Gets the schema description + std::string const &description() const; + + /// Returns a reference to the arguments + ArgumentDescriptionVector const &arguments() const { return arguments_; } + +public: + + // + // Basic overrides + // + + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const =0; + + /// Entry point to profile all operations in the manifest + virtual int profile_all( + Options const &options, + library::Manifest const &manifest, + DeviceContext &device_context); + +public: + + // + // Operation-specific phases of verification and profiling + // + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) = 0; + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) = 0; + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + 
DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) = 0; + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) = 0; + +public: + + // + // Static helpers + // + + /// Sleep for a given duration in ms + static void sleep(int sleep_duration); + + /// Returns true if the current operation description satisfies the problem space + static bool satisfies( + library::OperationDescription const &op_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Compares tensors for equality + static Disposition compare_tensors( + Options const &options, + DeviceAllocation &experimental, + DeviceAllocation &reference, + int64_t count = 0); + + static void save_workspace( + DeviceContext &device_context, + Options const &options, + library::OperationDescription const &desc, + library::Provider provider, + library::Provider verification_provider = library::Provider::kInvalid); + + /// Helper to set a performance result member + static void set_argument( + PerformanceResult &result, + char const *name, + ProblemSpace const &problem_space, + std::string const &value); + + /// Helper to set a performance result member + static void set_argument( + PerformanceResult &result, + char const *name, + ProblemSpace const &problem_space, + int64_t value); + +protected: + + /// Sets operation description + static void initialize_result_( + PerformanceResult &result, + library::OperationDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Method to profile an initialized CUTLASS operation + virtual Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, 
+ void *host_workspace, + void *device_workspace); + + /// Profiles the GPU kernel launched in `func` running simultaneously on all + /// requested devices. + Status profile_kernel_w_cuda_graphs_( + PerformanceResult& result, + Options const& options, + std::function const& func, + std::vector const& streams); + + Status profile_kernel_( + PerformanceResult& result, + Options const& options, + std::function const& func, + std::vector const& streams); + + /// Profiles the GPU kernel launched in `func` on the `stream` + Status profile_kernel_( + PerformanceResult& result, + Options const& options, + std::function const& func, + cudaStream_t stream = nullptr); + + /// Profiles the GPU kernel launched in `func` on the `stream` + Status profile_kernel_no_cuda_graphs_( + PerformanceResult& result, + Options const& options, + std::function const& func, + cudaStream_t stream = nullptr); + +private: + /// finds string matches filter_string in operation_name + bool find_string_matches_( + std::string const &filter_string, + std::string const &operation_name); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Vector of owning operation profilers +using OperationProfilerVector = std::vector>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/options.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/options.h new file mode 100644 index 0000000000000000000000000000000000000000..1a957b36eea35f7c0a5366645c3a62298ca56dea --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/options.h @@ -0,0 +1,384 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Command line options for performance test program +*/ + +#pragma once + +#include +#include +#include + +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/library/library.h" + +#include "enumerated_types.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Global options +class Options { +public: + + /// Cublas and cuDNN options + struct Library { + + // + // Data members + // + + /// Algorithm mode + AlgorithmMode algorithm_mode; + + /// Algorithm enumerants + std::vector algorithms; + + // + // Methods + // + + explicit Library(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + }; + + /// Options related to the selected device + struct Device { + + /// Device ID + std::vector devices; + + /// Number of total devices + /// This is not set by the user, it is set by automatically + int num_devices; + + /// CUDA Device properties + std::vector properties; + + /// Total memory allocation on each device + size_t maximum_capacity; + + private: + /// SM Count + /// Limits the number of SMs to use on each device + int sm_count; + + // + // Methods + // + public: + explicit Device(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + void print_device_info(std::ostream &out) const; + + /// Returns the device ID from a device index + int device_id(size_t device_index) const; + + /// Returns the sm_count if set, otherwise returns the number of SMs on the device + int get_sm_count(int device_index) const; + + /// Returns the compute capability of the listed devices (e.g. 70, 75, 80, etc.) 
+ int compute_capability(int device_index) const; + }; + + /// Options related to initializing input tensors + struct Initialization { + + /// If true, data is initialized randomly. If false, no initialization is performed after + /// allocating tensors. + bool enabled; + + /// If true, data distribution is set by the user and is not allowed to change + /// If false, data distribution is allowed to change based on element_type (library::NumericTypeID) + bool fix_data_distribution; + + /// Data distribution for input tensors + Distribution data_distribution; + + /// Source of random tensor elements + library::Provider provider; + + /// Random number generator seed. + int seed; + + // + // Methods + // + + explicit Initialization(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + + /// Helper to parse a Distribution object from the command line parser + static void get_distribution( + cutlass::CommandLine const &args, + std::string const &arg, + cutlass::Distribution &dist); + }; + + /// Options related to verification of the result + struct Verification { + + // + // Data members + // + + /// If true, kernels are verified before they are profiled + bool enabled; + + /// If true, causes profiler to return an error code if no reference check is run. + /// Only valid when verification is enabled. 
+ bool required; + + /// Relative error threshold - zero to require bit-level consistency + double epsilon; + + /// Values smaller than this are assumed to be zero + double nonzero_floor; + + /// List of providers used to verify each result + ProviderVector providers; + + /// Indicates when to save the workspace + SaveWorkspace save_workspace; + + // + // Methods + // + + explicit Verification(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + + /// Returns true if a provider is enabled + bool provider_enabled(library::Provider provider) const; + + /// Returns the index of a provider if its enabled + size_t index(library::Provider provider) const; + }; + + /// Options related to profiling + struct Profiling { + + /// Number of workspaces to rotate through to avoid cache-resident working sets + int workspace_count{0}; + + /// Number of iterations to warmup each kernel prior to profiling + int warmup_iterations{10}; + + /// Number of iterations to profile each kernel - if 0, kernels are launched up to the profiling duration + /// This will always override profiling-duration and min-iterations. + int iterations{100}; + + /// Time to spend profiling each kernel (ms) + int duration{10}; + + /// Minimum number of iterations to profile + int min_iterations{10}; + + /// If true, profiling with cuda graph enabled. + bool use_cuda_graphs{false}; + + /// If enabled, the CUTLASS profiler searches for the best-performing kernel + /// within the subset of kernels matching a kernel filter regex. The best + /// performance is determined by screening over a set of predefined M/N/K + /// sizes and performance-related parameters, including cluster shapes, + /// swizzle sizes, and rasterization orders. + /// For now, it only supports legacy GEMM and blockscaled GEMM. 
+ bool enable_kernel_performance_search{false}; + + /// If enabled, the CUTLASS profiler searches for the best-performing kernel + /// for a given M/N/K problem size by evaluating various performance-related + /// parameters such as cluster shapes, swizzle sizes, and rasterization orders. + /// For now, it only supports legacy GEMM and blockscaled GEMM. + bool enable_best_kernel_for_fixed_shape{false}; + + /// Number of ms to sleep between profiling periods (ms) + int sleep_duration{50}; + + /// If true, profiling is actually conducted. + bool enabled{true}; + + /// If true, profiling returns an error code if no kernels are found to match the filters. + bool error_on_no_match{false}; + + /// If true, profiling returns an error code if no kernel are profiled + // Sometimes the kernel matches but failed to profile (e.g. can_implement() error) + bool error_if_nothing_is_profiled{false}; + + /// List of providers of each functionality to be profiled + ProviderVector providers; + + // + // Methods + // + + explicit Profiling(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + + /// Returns true if a provider is enabled + bool provider_enabled(library::Provider provider) const; + + /// Returns the index of a provider if its enabled + size_t index(library::Provider provider) const; + }; + + /// Options related to reporting + struct Report { + + /// If true, result is appended to possibly existing file + bool append; + + /// Path to a file containing results + std::string output_path; + + /// Path to a file containing junit xml results + std::string junit_output_path; + + /// Sequence of tags to attach to each result + std::vector> pivot_tags; + + /// If true, reports status of all kernels including those that were + /// not run for the given arguments + bool report_not_run; + + /// Prints human-readable text to stdout. 
If false, nothing is written to stdout + bool verbose; + + /// Sort results by flops-per-byte + bool sort_flops_per_byte; + + /// Sort results by flops-per-second + bool sort_flops_per_sec; + + /// Prints the name of the kernel being profiled before running the kernel. + /// This is useful for determining which kernel is causing a run of the profiler to hang + bool print_kernel_before_running; + + // + // Methods + // + + explicit Report(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + }; + + /// Options related to printing usage and version information + struct About { + + /// If true, usage is printed and the program ends. + bool help; + + /// Prints version string + bool version; + + /// Print information about devices + bool device_info; + + // + // Methods + // + + explicit About(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + + static void print_version(std::ostream &out); + }; + +public: + + // + // Data members + // + + /// Top-level execution mode + ExecutionMode execution_mode; + + /// Name of math function to profile + library::OperationKind operation_kind; + + /// Vector of operation name substrings + std::vector operation_names; + + /// Map of problems to run for each operation + /// [operation_name] -> vector of problems, each problem specified as a vector of [argument name] -> [argument value] + std::unordered_map> operation_problems; + + /// Vector of operation name substrings + std::vector excluded_operation_names; + + + // + // Detailed configuration options + // + + /// Configuration + CommandLine cmdline; + Device device; + Initialization initialization; + Library library; + Verification verification; + Profiling profiling; + Report report; + About about; + +public: + + explicit Options(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + 
void print_options(std::ostream &out) const; + + static std::string indent_str(int indent); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_report.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_report.h new file mode 100644 index 0000000000000000000000000000000000000000..07102c99bc0f38a071e1ab828aab30678a3e2d44 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_report.h @@ -0,0 +1,128 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Class performing output during profiling +*/ + +#pragma once + +#include +#include + +// CUTLASS Profiler includes +#include "options.h" +#include "enumerated_types.h" +#include "performance_result.h" + +// CUTLASS Library includes +#include "cutlass/library/library.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +class PerformanceReport { +private: + + /// Reference to options + Options const &options_; + + /// Operation kind + library::OperationKind op_kind_; + + /// Operation file name containing performance report of op_kind + std::string op_file_name_; + + /// Output file containing results + std::ofstream output_file_; + + /// Operation file name containing junit performance report of op_kind + std::string op_junit_file_name_; + + /// Output file containing junit results + std::ofstream junit_output_file_; + + /// Flag indicating the performance report is valid + bool good_; + + /// Vector of argument names + std::vector argument_names_; + + /// 
Counter uniquely identifying problem within the report + size_t problem_index_; + + /// Collection of all results + PerformanceResultVector concatenated_results_; + +public: + + PerformanceReport(Options const &options, std::vector const &argument_names, library::OperationKind const &op_kind); + ~PerformanceReport(); + + bool good() const { return good_; } + + void next_problem(); + void append_result(PerformanceResult result); + void sort_flops_per_byte(PerformanceResultVector &results); + void sort_flops_per_sec(PerformanceResultVector &results); + void append_results(PerformanceResultVector const &results); + +public: + + /// Prints the CSV header + std::ostream & print_csv_header_(std::ostream &out); + + /// Prints the CSV + std::ostream & print_result_csv_(std::ostream &out, PerformanceResult const &result); + + /// @defgroup jUnit Result Generation + /// Functions related to generation of the jUnit results + /// @{ + + std::ostream & print_junit_header_(std::ostream &out); + std::ostream & print_junit_result_(std::ostream &out, PerformanceResult const &result); + std::ostream & print_junit_footer_(std::ostream &out); + + /// @} + + /// Prints the result in human readable form + std::ostream & print_result_pretty_( + std::ostream &out, + PerformanceResult const &result, + bool use_shell_coloring = true); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_result.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_result.h new file mode 100644 index 0000000000000000000000000000000000000000..986ac89bc86a267ce8fb181a986f28f3f0936566 --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_result.h @@ -0,0 +1,137 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Defines a math function +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" + +// CUTLASS Profiler includes +#include "enumerated_types.h" + +// CUTLASS Library includes +#include "cutlass/library/library.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Performance result object +struct PerformanceResult { + + /// Index of problem + size_t problem_index; + + /// library::Provider + library::Provider provider; + + /// Operation kind + library::OperationKind op_kind; + + /// CUTLASS status result from kernels (success or failure) + // Status does information on verification + Status status; + + /// Outcome of verification (worst case verification result) + Disposition disposition; + + /// Outcome of verification (all verification results) + DispositionMap verification_map; + + /// Operation name + std::string operation_name; + + /// Stringified vector of argument values + std::vector > arguments; + + /// Number of bytes read or written + int64_t bytes; + + /// Number of DL flops performed by the math function + int64_t flops; + + /// Average runtime in ms + double runtime; + + /// Average runtime in ms per device + std::vector runtime_vector; + + // + // Members + // + + /// Ctor + PerformanceResult(): + problem_index(0), + op_kind(library::OperationKind::kInvalid), + provider(library::Provider::kInvalid), + disposition(Disposition::kNotRun), + status(Status::kInvalid), + bytes(0), + flops(0), + runtime(0) + { } + + // Copy constructor for deep copy + PerformanceResult(const PerformanceResult& other) = default; + + // Explicitly define copy assignment operator + PerformanceResult& operator=(const PerformanceResult& other) = default; + + /// Returns true if the runtime is valid + bool good() const { + return runtime > 0; + } + + 
/// Math throughput in units of GFLOP/s + double gflops_per_sec() const { + return double(flops) / runtime / 1.0e6; + } + + /// memory bandwidth in units of GiB/s + double gbytes_per_sec() const { + return double(bytes) / double(1 << 30) / runtime * 1000.0; + } + +}; + +using PerformanceResultVector = std::vector; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/problem_space.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/problem_space.h new file mode 100644 index 0000000000000000000000000000000000000000..9bdbec657c10cff0dafebd2cb6cd52057f3695c9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/problem_space.h @@ -0,0 +1,1039 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief + + "Any sufficiently complicated C or Fortran program contains an ad-hoc, informally-specified, + bug-ridden, slow implementation of half of Common Lisp." + + - Greenspun's Tenth Rule of Programming + + + cutlass::profiler::ProblemSpace defines a set of data structures which represent the Cartesian + product of sequences defined by integer ranges, lists of scalars, and sets of enumerated types. + + These permit a single invocation of the CUTLASS Profiler to iterate over a large set of problems, + verify and profile various operations when they are compatible with the command line, and + construct data tables of results that are convenient inputs to post processing in Excel or Pandas. + + By executing multiple problems per invocation, startup overheads may be amortized across many + kernel launches. 
+*/ + +#pragma once + +// Standard Library includes +#include +#include +#include +#include +#include + +// CUTLASS Utility includes +#include "cutlass/util/command_line.h" + +// CUTLASS Library includes +#include "cutlass/library/library.h" + +// Profiler includes +#include "enumerated_types.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines the argument schema +struct ArgumentDescription { + + /// Type of argument + ArgumentTypeID type; + + /// Prioritized array of aliases used in command line parsing + std::vector aliases; + + /// Description of argument + std::string description; + + // + // Methods + // + + /// Default ctor + ArgumentDescription(): + type(ArgumentTypeID::kInvalid) { } + + /// Constructor with aliases + ArgumentDescription( + ArgumentTypeID type_, + std::vector const &aliases_, + std::string const &description_ + ): + type(type_), aliases(aliases_), description(description_) { } +}; + +/// Vector of arguments +using ArgumentDescriptionVector = std::vector; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Base class for kernel arguments +struct KernelArgument { + + // + // Type definitions + // + + /// Value base class + struct Value { + + KernelArgument const *argument; + bool not_null; + + // + // Methods + // + + Value( + KernelArgument const *argument_ = nullptr, + bool not_null_ = true + ): argument(argument_), not_null(not_null_) { } + + virtual ~Value() { } + + virtual std::ostream &print(std::ostream &out) const =0; + }; + + /// Abstract base class to iterate over values within arguments + struct ValueIterator { + + /// Indicates type of kernel argument + KernelArgument const *argument; + + /// If the iterator points to an argument that is null, it needs to be distinguished + /// from end. 
+ bool null_argument; + + // + // Methods + // + + /// Constructs a value iterator - no methods are valid if argument_ == nullptr + ValueIterator( + KernelArgument const *argument_ = nullptr, + bool null_argument_ = false): + argument(argument_), null_argument(null_argument_) { + + if (!argument_->not_null()) { + null_argument = true; + } + } + + virtual ~ValueIterator() { } + + /// Advances to next point in range + virtual void operator++() = 0; + + /// Compares against another value iterator - must be of the same KernelArgument type + virtual bool operator==(ValueIterator const &it) const = 0; + + /// Returns a unique_ptr object pointing to a newly created value object + virtual std::unique_ptr at() const = 0; + + /// Gets the type of the iterator + ArgumentTypeID type() const { + return argument->description->type; + } + + /// Helper to compute inequality + bool operator!=(ValueIterator const &it) const { + return !(*this == it); + } + + std::ostream &print(std::ostream &out) const; + }; + + // + // Data members + // + + /// Describes the argument + ArgumentDescription const *description; + + /// Parent node + KernelArgument *parent; + + /// Sequence in which the kernel argument is to be iterated over. + /// Smaller means faster changing. 
-1 is don't care + int ordinal; + + // + // Methods + // + + /// Default ctor + KernelArgument( + ArgumentDescription const *description_ = nullptr, + KernelArgument *parent_ = nullptr, + int ordinal_ = -1 + ): description(description_), parent(parent_), ordinal(ordinal_) { } + + virtual ~KernelArgument(); + + /// Returns true if the kernel argument iself is empty + virtual bool not_null() const =0; + + /// Returns a string name for debugging + std::string qualified_name() const { + if (description) { + if (description->aliases.empty()) { + return ""; + } + return description->aliases.front(); + } + return ""; + } + + virtual std::unique_ptr begin() const =0; + virtual std::unique_ptr end() const =0; +}; + +using KernelArgumentVector = std::vector>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a scalar argument type as a string that is lexically cast to the appropriate kernel +/// type. +struct ScalarArgument : public KernelArgument { + + // + // Type definitions + // + + /// Value type + struct ScalarValue : public KernelArgument::Value { + + std::string value; + + // + // Methods + // + + ScalarValue( + std::string const &value_ = "", + ScalarArgument const *argument = nullptr, + bool not_null_ = true + ); + + virtual std::ostream &print(std::ostream &out) const; + }; + + using ValueCollection = std::vector; + + /// Abstract base class to iterate over values within arguments + struct ScalarValueIterator : public KernelArgument::ValueIterator { + + // + // Data members + // + + ValueCollection::const_iterator value_it; + + // + // Methods + // + + explicit ScalarValueIterator(ScalarArgument const *argument = nullptr); + + virtual void operator++(); + virtual bool operator==(ValueIterator const &it) const; + + /// Gets the value pointed to + virtual std::unique_ptr at() const; + }; + + // + // Data members + // + + /// Set of possible values + ValueCollection values; + + // + // Methods + // + + 
/// Closed range supporting additive increment
struct Range {

  //
  // Type definitions
  //

  /// Selects how Iterator::at() turns an iterator position into a value
  enum class Mode {
    kSequence,    ///< at() returns the iterator position itself
    kRandom,      ///< at() returns a random sample between minimum and maximum
    kRandomLog2,  ///< at() returns a random sample drawn on a log2 scale
    kInvalid
  };

  /// Forward iterator over a Range. Equality compares the position only.
  struct Iterator {

    int64_t value;        ///< current position
    int64_t increment;    ///< amount added to `value` by each ++
    Range const *range;   ///< owning range, or nullptr for a free-standing counter

    //
    // Methods
    //

    Iterator(
      int64_t value_ = 0,
      int64_t increment_ = 1,
      Range const *range_ = nullptr
    ):
      value(value_), increment(increment_), range(range_) { }

    Iterator & operator++() {
      value += increment;
      return *this;
    }

    Iterator operator++(int) {
      Iterator self(*this);
      ++(*this);
      return self;
    }

    bool operator==(Iterator const &it) const {
      return value == it.value;
    }

    bool operator!=(Iterator const &it) const {
      return !(*this == it);
    }

    /// Rounds `value` to a multiple of `divisible`: up when the remainder exceeds
    /// half of `divisible`, otherwise the remainder is simply removed.
    /// A `divisible` of zero leaves `value` untouched instead of evaluating
    /// `value % 0`, which is undefined behavior.
    static int64_t round(int64_t value, int64_t divisible) {
      if (divisible == 0) {
        return value;
      }

      int64_t rem = (value % divisible);

      // Round either up or down
      if (rem > divisible / 2) {
        value += (divisible - rem);
      }
      else {
        value -= rem;
      }

      return value;
    }

    /// Returns the value at the current position.
    /// In the random modes every call draws a fresh sample from std::rand(), so
    /// repeated calls at the same position generally return different values.
    int64_t at() const {
      if (!range) {
        return value;
      }

      switch (range->mode) {
        case Mode::kRandom: {
          double rnd = double(range->minimum) +
            double(std::rand()) / double(RAND_MAX) * (double(range->maximum) - double(range->minimum));

          return round(int64_t(rnd), range->divisible);
        }

        case Mode::kRandomLog2: {
          // The logarithm is only defined for positive bounds; clamp to 1 so that a
          // bound of zero (or below) cannot turn the sample into NaN.
          double lower = double(range->minimum < 1 ? 1 : range->minimum);
          double upper = double(range->maximum < 1 ? 1 : range->maximum);

          double lg2_minimum = std::log(lower) / std::log(2.0);
          double lg2_maximum = std::log(upper) / std::log(2.0);
          double rnd = lg2_minimum + double(std::rand()) / double(RAND_MAX) * (lg2_maximum - lg2_minimum);

          return round(int64_t(std::pow(2.0, rnd)), range->divisible);
        }

        case Mode::kSequence:
        default:
          break;
      }
      return value;
    }

    int64_t operator*() const {
      return at();
    }
  };

  //
  // Data members
  //

  int64_t first;      ///< first element in range
  int64_t last;       ///< last element in range
  int64_t increment;  ///< additive increment between values (constructors never leave it zero)

  Mode mode;          ///< mode selection enables alternative values
  int64_t minimum;    ///< minimum value to return
  int64_t maximum;    ///< maximum value to return
  int64_t divisible;  ///< random samples are rounded to an integer multiple of this value

  //
  // Methods
  //

  /// Default constructor - range acts as a scalar
  Range(int64_t first_ = 0): first(first_), last(first_), increment(1), mode(Mode::kSequence), minimum(0), maximum(0), divisible(1) { }

  /// Range acts as a range
  Range(
    int64_t first_,
    int64_t last_,
    int64_t increment_ = 1,
    Mode mode_ = Mode::kSequence,
    int64_t minimum_ = 0,
    int64_t maximum_ = 0,
    int64_t divisible_ = 1
  ): first(first_), last(last_), increment(increment_), mode(mode_), minimum(minimum_), maximum(maximum_), divisible(divisible_) {

    // Helpers to avoid constructing invalid ranges
    if (increment > 0) {
      if (last < first) {
        std::swap(last, first);
      }
    }
    else if (increment < 0) {
      if (first < last) {
        std::swap(last, first);
      }
    }
    else {
      // A zero increment can never advance and makes end() divide by zero, so the
      // range always collapses to the single element `first` - including the case
      // first == last, which used to keep increment == 0.
      last = first;
      increment = 1;
    }
  }

  /// Helper to construct a sequence range
  static Range Sequence(int64_t first_, int64_t last_, int64_t increment_ = 1) {
    return Range(first_, last_, increment_, Mode::kSequence);
  }

  /// Helper to construct a range that is a random distribution
  static Range Random(int64_t minimum_, int64_t maximum_, int64_t count_, int64_t divisible_ = 1) {
    return Range(1, count_, 1, Mode::kRandom, minimum_, maximum_, divisible_);
  }

  /// Helper to construct a range that is a random distribution over a log scale
  static Range RandomLog2(int64_t minimum_, int64_t maximum_, int64_t count_, int64_t divisible_ = 1) {
    return Range(1, count_, 1, Mode::kRandomLog2, minimum_, maximum_, divisible_);
  }

  /// Returns an iterator to the first element within the range
  Iterator begin() const {
    return Iterator(first, increment, this);
  }

  /// Returns an iterator to the first element *after* the range
  Iterator end() const {
    return Iterator(first + ((last - first)/increment + 1) * increment, increment, this);
  }
};
std::unique_ptr begin() const; + virtual std::unique_ptr end() const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure defining the data type of tensors +struct TensorArgument : public KernelArgument { + + // + // Type definitions + // + + struct TensorDescription { + + /// Data type of elements + library::NumericTypeID element; + + /// Layout definition + library::LayoutTypeID layout; + + /// Computed extent + std::vector extent; + + /// Enables directly specifying stride value used to size tensor + std::vector stride; + + // + // Methods + // + + TensorDescription( + library::NumericTypeID element_ = library::NumericTypeID::kUnknown, + library::LayoutTypeID layout_ = library::LayoutTypeID::kUnknown, + std::vector extent_ = std::vector(), + std::vector stride_ = std::vector() + ): + element(element_), layout(layout_), extent(extent_), stride(stride_) {} + }; + + using ValueCollection = std::vector; + + /// Value structure + struct TensorValue : public KernelArgument::Value { + + TensorDescription desc; + + // + // Methods + // + + TensorValue( + TensorDescription const &desc_ = TensorDescription(), + TensorArgument const *argument_ = nullptr, + bool not_null_ = true + ); + + /// Pretty printer for debugging + virtual std::ostream &print(std::ostream &out) const; + }; + + /// Abstract base class to iterate over values within arguments + struct TensorValueIterator : public KernelArgument::ValueIterator { + + // + // Data members + // + + ValueCollection::const_iterator value_it; + + // + // Methods + // + + explicit TensorValueIterator(TensorArgument const *argument_); + + virtual void operator++(); + virtual bool operator==(ValueIterator const &it) const; + + /// Gets the value pointed to + virtual std::unique_ptr at() const; + }; + + /// Set of possible values + ValueCollection values; + + // + // Methods + // + + /// Default ctor + explicit TensorArgument( + ArgumentDescription const 
*description + ): + KernelArgument(description) { } + + virtual bool not_null() const { + return !values.empty(); + } + + virtual std::unique_ptr begin() const; + virtual std::unique_ptr end() const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Numeric data type +struct EnumeratedTypeArgument : public KernelArgument { + + // + // Type definitions + // + + struct EnumeratedTypeValue : public KernelArgument::Value { + + /// Data type of element + std::string element; + + // + // Methods + // + + EnumeratedTypeValue( + std::string const &element_ = std::string(), + EnumeratedTypeArgument const *argument_ = nullptr, + bool not_null_ = true + ); + + /// Pretty printer for debugging + virtual std::ostream &print(std::ostream &out) const; + }; + + using ValueCollection = std::vector; + + /// Abstract base class to iterate over values within arguments + struct EnumeratedTypeValueIterator : public KernelArgument::ValueIterator { + + // + // Data members + // + + ValueCollection::const_iterator value_it; + + // + // Methods + // + + explicit EnumeratedTypeValueIterator(EnumeratedTypeArgument const *argument_ = nullptr); + + virtual void operator++(); + virtual bool operator==(ValueIterator const &it) const; + + /// Gets the value pointed to + virtual std::unique_ptr at() const; + }; + + // + // Data members + // + + ValueCollection values; + + // + // Members + // + + /// Default ctor + explicit EnumeratedTypeArgument(ArgumentDescription const *description): + KernelArgument(description) {} + + virtual bool not_null() const { + return !values.empty(); + } + + virtual std::unique_ptr begin() const; + virtual std::unique_ptr end() const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Object storing the space argument values +class ProblemSpace { +public: + + /// Tuple of arguments + using Problem = std::vector>; + + /// Type used to iterator over 
things + using IteratorVector = std::vector>; + + /// Iterates over points in the design space + class Iterator { + private: + + /// One iterator per argument + IteratorVector iterators; + + public: + + // + // Methods + // + + explicit Iterator(); + Iterator(ProblemSpace const &problem_space); + Iterator(Iterator &&it); + + // Rule of three + Iterator(Iterator const &) = delete; + Iterator &operator=(Iterator const &it) = delete; + ~Iterator() = default; + + /// Pre-increment - advances to next point in argument range + void operator++(); + + /// Gets the current argument value + Problem at() const; + + /// Moves iterator to end + void move_to_end(); + + /// Equality operator + bool operator==(Iterator const &it) const; + + /// Inequality operator + bool operator!=(Iterator const &it) const { + return !(*this == it); + } + + /// Helper to call at() method + Problem operator*() const { + return at(); + } + + /// Helper to print iterator state + std::ostream & print(std::ostream &out) const; + + private: + + /// Helper for recursively constructing iterators + void construct_(KernelArgument const *argument); + }; + +public: + + // + // Data members + // + + KernelArgumentVector arguments; + + /// Map of argument names to their position within the argument vector + std::unordered_map argument_index_map; + +public: + + // + // Methods + // + + /// Default ctor + ProblemSpace() = default; + + /// Constructs a problem space from a vector of arguments. This vector must outlive + /// the ProblemSpace object, which stores pointers to objects within the + /// ArgumentDescriptionVector. 
+ ProblemSpace(ArgumentDescriptionVector const &schema, CommandLine const &cmdline); + + Iterator begin() const; // returns an iterator to the first point in the range + Iterator end() const; // returns an iterator to the first point after the range + + /// Returns the index of an argument by name + size_t argument_index(char const *name) const; + + /// Gets all argument names as an ordered vector + std::vector argument_names() const; + + /// Returns the number of dimensions of the problem space + size_t rank() const { return arguments.size(); } + +private: + + /// Helper for recursively cloning + void clone_( + KernelArgumentVector &kernel_args, + ArgumentDescription const *arg_desc); + + /// Parses command line argument + void parse_( + KernelArgument *arg, + CommandLine const &cmdline); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Lexically casts an argument to an int if it is defined. Returns true if not null. +bool arg_as_int(int &int_value, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_int(int64_t &int_value, KernelArgument::Value const *value_ptr); + +/// Lexically casts the named argument to an int if it is defined. Returns true if not null. +bool arg_as_int( + int &int_value, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts the named argument to an int64 if it is defined. Returns true if not null. +bool arg_as_int( + int64_t &int_value, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +bool arg_as_bool(bool &bool_value, KernelArgument::Value const *value_ptr); + +bool arg_as_bool(bool &bool_value, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to a library::NumericTypeID if it is defined. Returns true if not null.
+bool arg_as_NumericTypeID(library::NumericTypeID &numeric_type, KernelArgument::Value const *value_ptr); + +/// Lexically casts the named argument to a library::NumericTypeID if it is defined. Returns true if not null. +bool arg_as_NumericTypeID( + library::NumericTypeID &numeric_type, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to a library::LayoutTypeID if it is defined. Returns true if not null. +bool arg_as_LayoutTypeID(library::LayoutTypeID &layout_type, KernelArgument::Value const *value_ptr); + +/// Lexically casts the named argument to a library::LayoutTypeID if it is defined. Returns true if not null. +bool arg_as_LayoutTypeID( + library::LayoutTypeID &layout_type, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Lexically casts an argument to a library::OpcodeClassID if it is defined. Returns true if not null. +bool arg_as_OpcodeClassID(library::OpcodeClassID &opcode_class, KernelArgument::Value const *value_ptr); + +/// Lexically casts the named argument to a library::OpcodeClassID if it is defined. Returns true if not null. +bool arg_as_OpcodeClassID( + library::OpcodeClassID &opcode_class, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Lexically casts an argument to a library::SplitKMode if it is defined. Returns true if not null. +bool arg_as_SplitKModeID(library::SplitKMode &split_k_mode, KernelArgument::Value const *value_ptr); + +/// Lexically casts the named argument to a library::SplitKMode if it is defined. Returns true if not null. +bool arg_as_SplitKModeID( + library::SplitKMode &split_k_mode, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to a library::ConvModeID if it is defined. Returns true if not null. +bool arg_as_ConvModeID(library::ConvModeID &conv_mode, KernelArgument::Value const *value_ptr); + +/// Lexically casts the named argument to a library::ConvModeID if it is defined. Returns true if not null.
+bool arg_as_ConvModeID( + library::ConvModeID &conv_mode, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to a library::IteratorAlgorithmID if it is defined. Returns true if not null. +bool arg_as_IteratorAlgorithmID(library::IteratorAlgorithmID &iterator_algorithm, KernelArgument::Value const *value_ptr); + +/// Lexically casts the named argument to a library::IteratorAlgorithmID if it is defined. Returns true if not null. +bool arg_as_IteratorAlgorithmID( + library::IteratorAlgorithmID &iterator_algorithm, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Lexically casts an argument to a library::RuntimeDatatype if it is defined. Returns true if not null. +bool arg_as_RuntimeDatatype(library::RuntimeDatatype &runtime_datatype, KernelArgument::Value const *value_ptr); + +/// Lexically casts the named argument to a library::RuntimeDatatype if it is defined. Returns true if not null. +bool arg_as_RuntimeDatatype( + library::RuntimeDatatype &runtime_datatype, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Lexically casts an argument to a library::RasterOrder if it is defined. Returns true if not null. +bool arg_as_RasterOrder(library::RasterOrder &raster_order, KernelArgument::Value const *value_ptr); + +/// Lexically casts the named argument to a library::RasterOrder if it is defined. Returns true if not null. +bool arg_as_RasterOrder( + library::RasterOrder &raster_order, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to a library::Provider if it is defined. Returns true if not null. +bool arg_as_ProviderID(library::Provider &provider, KernelArgument::Value const *value_ptr); + +/// Lexically casts the named argument to a library::Provider if it is defined. Returns true if not null.
+bool arg_as_ProviderID( + library::Provider &provider, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. +bool arg_as_scalar( + std::vector &bytes, + library::NumericTypeID numeric_type, + KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. +bool arg_as_scalar( + std::vector &bytes, + library::NumericTypeID numeric_type, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +bool arg_as_string( + std::string& arg, + char const* name, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + +/// Returns true if a tensor description satisfies a `tensor` value +bool tensor_description_satisfies( + library::TensorDescription const &tensor_desc, + TensorArgument::TensorValue const *value_ptr); + +/// Returns true if a tensor description satisfies a `tensor` value +bool tensor_description_satisfies( + library::TensorDescription const &tensor_desc, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Returns true if a conv kind satisfies the value +bool conv_kind_satisfies( + library::ConvKind const &conv_kind, + EnumeratedTypeArgument::EnumeratedTypeValue const *value_ptr); + +/// Returns true if a conv kind satisfies the value +bool conv_kind_satisfies( + library::ConvKind const &conv_kind, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Returns true if an iterator algorithm satisfies the value +bool iterator_algorithm_satisfies( + library::IteratorAlgorithmID const &iterator_algorithm, + EnumeratedTypeArgument::EnumeratedTypeValue const *value_ptr); + +/// Returns true if an iterator algorithm satisfies the value +bool iterator_algorithm_satisfies( +
library::IteratorAlgorithmID const &iterator_algorithm, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_2k_operation_profiler.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_2k_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..ba47a6832077984c334a5467257a151735b088b3 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_2k_operation_profiler.h @@ -0,0 +1,229 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines profiling functionality for rank-2k operations + + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/blas3.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Operation profiler for rank-2k operations +class Rank2KOperationProfiler : public OperationProfiler { +public: + + /// Rank-2k problem structure obtained from problem space + struct RankKProblem { + int64_t n; + int64_t k; + int64_t lda; + int64_t ldb; + int64_t ldc; + FillMode fill_mode; + BlasMode blas_mode; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; +
int64_t batch_count; + + // + // Methods + // + + RankKProblem(): + n(16), k(16), lda(0), ldb(0), ldc(0), // all leading dimensions, including ldb, start at 0 + fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), + split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Total number of bytes loaded + int64_t bytes(library::RankKDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::RankKDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used: device allocations plus library configuration and arguments + struct RankKWorkspace { + + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::RankKConfiguration configuration; + library::RankKArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + RankKWorkspace(): + A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// Rank-2k problem obtained from problem space + RankKProblem problem_; + + /// Device memory allocations + RankKWorkspace rank_k_workspace_; + + +public: + // + // Methods + // + + /// Ctor + Rank2KOperationProfiler(Options const &options); + + /// Destructor + virtual ~Rank2KOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options,
+ PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_k_operation_profiler.h
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_k_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..fff190a7570cd5811c6e5de6284bf96e40c404b7 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_k_operation_profiler.h @@ -0,0 +1,227 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines profiling functionality for rank-k operations + + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/blas3.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Operation profiler for rank-k operations +class RankKOperationProfiler : public OperationProfiler { +public: + + /// Rank-k problem structure obtained from problem space + struct RankKProblem { + int64_t n; + int64_t k; + int64_t lda; + int64_t ldc; + FillMode fill_mode; + BlasMode blas_mode; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + + // + // Methods + // + + RankKProblem(): + n(16), k(16), lda(0), ldc(0), + fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), + split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::RankKDescription const
&operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Total number of bytes loaded + int64_t bytes(library::RankKDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::RankKDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used: device allocations plus library configuration and arguments + struct RankKWorkspace { + + DeviceAllocation *A; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::RankKConfiguration configuration; + library::RankKArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + RankKWorkspace(): + A(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// Rank-k problem obtained from problem space + RankKProblem problem_; + + /// Device memory allocations + RankKWorkspace rank_k_workspace_; + + +public: + // + // Methods + // + + /// Ctor + RankKOperationProfiler(Options const &options); + + /// Destructor + virtual ~RankKOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, +
DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/reduction_operation_profiler.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/reduction_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..0c81ef4637175a6de1f44cedddf319436aaff24d --- /dev/null +++
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/reduction_operation_profiler.h @@ -0,0 +1,173 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Defines profiling functionality for reduction operation + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#if CUTLASS_ENABLE_CUDNN +#include "cudnn_helpers.h" +#endif //#if CUTLASS_ENABLE_CUDNN +#include "debug.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Operation profiler for reduction operations +class ReductionOperationProfiler : public OperationProfiler { +public: + + + /// Workspace used + struct ReductionWorkspace { + + /// Reduction device allocations + DeviceAllocation *Workspace; + DeviceAllocation *Source; + DeviceAllocation *Destination; + DeviceAllocation *Reference; + + /// Library configuration and arguments + library::ReductionConfiguration configuration; + library::ReductionArguments arguments; + + /// Buffer used for the cutlass operations' host workspace + std::vector host_workspace; + + /// Buffer used for the cutlass operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + ReductionWorkspace(): + Workspace(nullptr), Source(nullptr), Destination(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// Reduction problem obtained from problem space + MatrixCoord problem_; + + /// Device memory allocations (a ReductionWorkspace, despite the conv_ prefix in the member name) + ReductionWorkspace conv_workspace_; + + +public: + // + // Methods + // + + /// Ctor +
ReductionOperationProfiler(Options const &options); + + /// Destructor + virtual ~ReductionOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/sparse_gemm_operation_profiler.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/sparse_gemm_operation_profiler.h new file mode 100644 index
0000000000000000000000000000000000000000..60204d8c9d458ab12020a6492de23174739aa584 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/sparse_gemm_operation_profiler.h @@ -0,0 +1,214 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines profiling functionality for sparse GEMM operations + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "gemm_operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Operation profiler for sparse GEMM operations +class SparseGemmOperationProfiler : public OperationProfiler { +public: + + /// Sparse GEMM problem structure obtained from problem space + struct SparseGemmProblem { + int64_t m; + int64_t n; + int64_t k; + int64_t lda; + int64_t ldb; + int64_t ldc; + int64_t lde; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + static int const sparse = 2; + // every 128b ElementA uses one elementE (the constructor initializes this to 0) + int elements_per_128b; + + // + // Methods + // + + SparseGemmProblem(): + m(16), n(16), k(16), lda(0), ldb(0), ldc(0), lde(0), split_k_slices(1), batch_count(1), elements_per_128b(0) { } + + /// Parses the problem
+ Status parse( + library::SparseGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::SparseGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used: device allocations plus library configuration and arguments + struct SparseGemmWorkspace { + + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *C; + DeviceAllocation *E; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::SparseGemmConfiguration configuration; + library::SparseGemmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + SparseGemmWorkspace(): + A(nullptr), B(nullptr), C(nullptr), E(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// Sparse GEMM problem obtained from problem space + SparseGemmProblem problem_; + + /// Device memory allocations + SparseGemmWorkspace gemm_workspace_; + + +public: + // + // Methods + // + + /// Ctor + SparseGemmOperationProfiler(Options const &options); + + /// Destructor + virtual ~SparseGemmOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const
&problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::SparseGemmDescription const &operation_desc, + ProblemSpace const &problem_space); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/symm_operation_profiler.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/symm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..94ded5e803bf914e5ae8c4ebb867cfe42ef829bc --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/symm_operation_profiler.h @@ -0,0 +1,230 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Defines a math function + + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/blas3.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Abstract base class for each math function +class SymmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct SymmProblem { + int64_t m; + int64_t n; + int64_t lda; + int64_t ldb; + int64_t ldc; + SideMode side_mode; + FillMode fill_mode; + BlasMode blas_mode; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + + // + // Methods + // + + SymmProblem(): + m(16), n(16), lda(0), ldb(0), ldc(0), + side_mode(SideMode::kInvalid), fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), + split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::SymmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Total number of bytes loaded + int64_t bytes(library::SymmDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::SymmDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::SymmDescription const &operation_desc, + 
ProblemSpace const &problem_space); + }; + + /// Workspace used + struct SymmWorkspace { + + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::SymmConfiguration configuration; + library::SymmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + SymmWorkspace(): + A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + SymmProblem problem_; + + /// Device memory allocations + SymmWorkspace symm_workspace_; + + +public: + // + // Methods + // + + /// Ctor + SymmOperationProfiler(Options const &options); + + /// Destructor + virtual ~SymmOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures 
performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::SymmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..9f21dafa0ecc869840fdba0a9c4414a89bbf4a7d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Defines a math function + + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/blas3.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class TrmmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct TrmmProblem { + int64_t m; + int64_t n; + int64_t lda; + int64_t ldb; + int64_t ldd; + SideMode side_mode; + FillMode fill_mode; + DiagType diag_type; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + + // + // Methods + // + + TrmmProblem(): + m(16), n(16), lda(0), ldb(0), ldd(0), split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::TrmmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::TrmmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct TrmmWorkspace { + + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *D; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::TrmmConfiguration configuration; + library::TrmmArguments arguments; + + /// Buffer used 
for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + TrmmWorkspace(): + A(nullptr), B(nullptr), D(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + TrmmProblem problem_; + + /// Device memory allocations + TrmmWorkspace trmm_workspace_; + + +public: + // + // Methods + // + + /// Ctor + TrmmOperationProfiler(Options const &options); + + /// Destructor + virtual ~TrmmOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void 
initialize_result_( + PerformanceResult &result, + Options const &options, + library::TrmmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c2727c989e645eca8e67a5d8d50391ced803cffa --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp @@ -0,0 +1,67 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +struct GPU_Clock +{ + GPU_Clock() { + cudaEventCreate(&start_); + cudaEventCreate(&stop_); + cudaEventRecord(start_); + } + + ~GPU_Clock() { + cudaEventDestroy(start_); + cudaEventDestroy(stop_); + } + + void start() { + cudaEventRecord(start_); + } + + float milliseconds() { + cudaEventRecord(stop_); + cudaEventSynchronize(stop_); + float time; + cudaEventElapsedTime(&time, start_, stop_); + return time; + } + + float seconds() { + return milliseconds() * float(1e-3); + } + + private: + cudaEvent_t start_, stop_; +}; diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h new file mode 100644 index 
0000000000000000000000000000000000000000..c95bd1cbeb56cc566394b155ea7ac24f07c28162 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h @@ -0,0 +1,324 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
/**
 * Utility for parsing command line arguments.
 *
 * Arguments of the form `--key` or `--key=value` are stored as parallel
 * entries in `keys` / `values` (a bare `--key` gets an empty value). Every
 * other argument is stored, in order, in `args` ("naked" arguments).
 *
 * NOTE(review): template arguments were lost in this copy of the header and
 * are restored here (std::string containers, `typename value_t` templates) --
 * confirm against the upstream header.
 */
struct CommandLine {
  std::vector<std::string> keys;    ///< Key of each `--key[=value]` argument
  std::vector<std::string> values;  ///< Value paired with keys[i]; empty when no `=` was given
  std::vector<std::string> args;    ///< Arguments that do not start with `--`

  /**
   * Constructor. argv[0] is skipped; argv[1..argc-1] are classified as
   * described on the struct.
   */
  CommandLine(int argc, const char** argv) {
    for (int i = 1; i < argc; i++) {
      std::string arg = argv[i];

      // Anything that does not begin with "--" is a naked argument.
      if (arg.compare(0, 2, "--") != 0) {
        args.push_back(arg);
        continue;
      }

      // Split "--key=value" at the first '='; the leading "--" is dropped.
      std::string::size_type pos = arg.find('=');
      if (pos == std::string::npos) {
        keys.push_back(arg.substr(2));
        values.push_back(std::string());
      } else {
        keys.push_back(arg.substr(2, pos - 2));
        values.push_back(arg.substr(pos + 1));
      }
    }
  }

  /**
   * Constructor to represent a command line from a map of [argument] -> [value].
   *
   * Takes the map by const reference so const maps and temporaries are
   * accepted as well as the non-const lvalues the previous signature required.
   */
  CommandLine(const std::unordered_map<std::string, std::string>& arg_map) {
    for (const auto& [key, value] : arg_map) {
      keys.push_back(key);
      values.push_back(value);
    }
  }

  /**
   * Checks whether a flag "--<flag>" is present in the commandline
   */
  bool check_cmd_line_flag(const char* arg_name) const {
    for (const std::string& key : keys) {
      if (key == arg_name) return true;
    }
    return false;
  }

  /**
   * Returns number of naked (non-flag and non-key-value) commandline parameters
   */
  size_t num_naked_args() const {
    return args.size();
  }

  /**
   * Print naked (non-flag and non-key-value) commandline parameters, one per line
   */
  void print_naked_args(std::ostream &out) const {
    for (const auto& arg : args) {
      out << " " << arg <<"\n";
    }
  }

  /**
   * Returns the naked commandline parameter at a given index (not including
   * flags). `val` is left untouched when `index` is out of range.
   */
  template <typename value_t>
  void get_cmd_line_argument(size_t index, value_t& val) const {
    if (index < args.size()) {
      std::istringstream str_stream(args[index]);
      str_stream >> val;
    }
  }

  /**
   * Obtains the boolean value specified for a given commandline parameter
   * --<flag>=<bool>. A present flag is true unless its value is exactly
   * "0" or "false"; an absent flag yields `_default`.
   */
  void get_cmd_line_argument(const char* arg_name, bool& val, bool _default) const {
    val = _default;
    if (check_cmd_line_flag(arg_name)) {
      std::string value;
      get_cmd_line_argument(arg_name, value);

      val = !(value == "0" || value == "false");
    }
  }

  /**
   * Obtains the value specified for a given commandline parameter
   * --<flag>=<value>, keeping the current content of `val` as the default.
   */
  template <typename value_t>
  void get_cmd_line_argument(const char* arg_name,
                             value_t& val) const {

    get_cmd_line_argument(arg_name, val, val);
  }

  /**
   * Obtains the value specified for a given commandline parameter
   * --<flag>=<value>, or `_default` when the flag is absent.
   *
   * The value is parsed with stream extraction. If the flag occurs several
   * times, every occurrence is parsed in order, so the last one wins.
   */
  template <typename value_t>
  void get_cmd_line_argument(const char* arg_name,
                             value_t& val,
                             value_t const& _default) const {
    val = _default;

    for (size_t i = 0; i < keys.size(); ++i) {
      if (keys[i] == arg_name) {
        std::istringstream str_stream(values[i]);
        str_stream >> val;
      }
    }
  }

  /**
   * Returns the values specified for a given commandline parameter
   * --<flag>=<value>,<value>*
   *
   * When the flag is present `vals` is cleared first and then filled from
   * every occurrence of the flag; when absent `vals` is left untouched.
   */
  template <typename value_t>
  void get_cmd_line_arguments(const char* arg_name,
                              std::vector<value_t>& vals,
                              char sep = ',') const {
    if (check_cmd_line_flag(arg_name)) {
      // Clear any default values
      vals.clear();

      // Recover from multi-value string
      for (size_t i = 0; i < keys.size(); ++i) {
        if (keys[i] == arg_name) {
          separate_string(values[i], vals, sep);
        }
      }
    }
  }

  /**
   * Returns the values specified for a given commandline parameter
   * --<flag>=<value>:<value>,<value>:<value>* as (first, second) pairs.
   */
  void get_cmd_line_argument_pairs(const char* arg_name,
                                   std::vector<std::pair<std::string, std::string> >& tokens,
                                   char delim = ',',
                                   char sep = ':') const {
    if (check_cmd_line_flag(arg_name)) {
      std::string value;
      get_cmd_line_argument(arg_name, value);

      tokenize(tokens, value, delim, sep);
    }
  }

  /**
   * Returns a list of ranges specified for a given commandline parameter
   * --<flag>=<range>,<range>* where each range is split again on `sep`.
   * Results are appended to `vals`.
   */
  void get_cmd_line_argument_ranges(const char* arg_name,
                                    std::vector<std::vector<std::string> >& vals,
                                    char delim = ',',
                                    char sep = ':') const {
    std::vector<std::string> ranges;
    get_cmd_line_arguments(arg_name, ranges, delim);

    for (const std::string& range : ranges) {
      std::vector<std::string> range_vals;
      separate_string(range, range_vals, sep);
      vals.push_back(range_vals);
    }
  }

  /**
   * The number of <key,value> pairs parsed
   */
  int parsed_argc() const { return (int)keys.size(); }

  //-------------------------------------------------------------------------
  // Utility functions
  //-------------------------------------------------------------------------

  /// Tokenizes a `delim`-delimited list of string pairs separated by `sep`.
  /// An item without `sep` yields (item, ""). Results are appended to `tokens`.
  static void tokenize(std::vector<std::pair<std::string, std::string> >& tokens,
                       std::string const& str,
                       char delim = ',',
                       char sep = ':') {
    // Home-built to avoid Boost dependency
    size_t s_idx = 0;
    while (s_idx < str.size()) {
      size_t d_idx = str.find_first_of(delim, s_idx);

      // End of the current item: next delimiter or end of string.
      size_t end_idx = (d_idx != std::string::npos ? d_idx : str.size());
      size_t sep_idx = str.find_first_of(sep, s_idx);
      size_t offset = 1;
      if (sep_idx == std::string::npos || sep_idx >= end_idx) {
        // No separator inside this item: the second element is empty.
        sep_idx = end_idx;
        offset = 0;
      }

      tokens.push_back(std::pair<std::string, std::string>(
        str.substr(s_idx, sep_idx - s_idx),
        str.substr(sep_idx + offset, end_idx - sep_idx - offset)));

      s_idx = end_idx + 1;
    }
  }

  /// Same tokenization as the pair overload, keeping only the first element
  /// of each pair. Results are appended to `tokens`.
  static void tokenize(std::vector<std::string>& tokens,
                       std::string const& str,
                       char delim = ',',
                       char sep = ':') {
    std::vector<std::pair<std::string, std::string> > token_pairs;
    tokenize(token_pairs, str, delim, sep);
    for (const auto& tok : token_pairs) {
      tokens.push_back(tok.first);
    }
  }

  /// Splits `str` on `sep`, parses each field with stream extraction and
  /// appends the results to `vals`.
  ///
  /// Empty fields before a separator are skipped. One final value is always
  /// appended after the last separator, even when nothing is left to parse;
  /// in that case the previous value (or a value-initialized value_t when
  /// there was none) is appended.
  template <typename value_t>
  static void separate_string(std::string const& str,
                              std::vector<value_t>& vals,
                              char sep = ',') {
    std::istringstream str_stream(str);
    std::string::size_type old_pos = 0;
    std::string::size_type new_pos = 0;

    // Value-initialized so a failed first extraction (e.g. empty `str`) never
    // pushes an indeterminate value.
    value_t val{};

    // Iterate <sep>-delimited values
    while ((new_pos = str.find(sep, old_pos)) != std::string::npos) {
      if (new_pos != old_pos) {
        // width() bounds string extraction to the current field.
        str_stream.width(new_pos - old_pos);
        str_stream >> val;
        vals.push_back(val);
      }

      // skip over delimiter
      str_stream.ignore(1);
      old_pos = new_pos + 1;
    }

    // Read last value
    str_stream >> val;
    vals.push_back(val);
  }
};
+/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include + +//-- BLAM_DEBUG_OUT --------------------------------------------------------- +#ifdef BLAM_DEBUG +# include +# ifndef BLAM_DEBUG_OUT +# define BLAM_DEBUG_OUT(msg) std::cerr << "BLAM: " << msg << std::endl +# define BLAM_DEBUG_OUT_2(msg) std::cerr << msg << std::endl +# endif // BLAM_DEBUG_OUT +#else +# ifndef BLAM_DEBUG_OUT +# define BLAM_DEBUG_OUT(msg) +# define BLAM_DEBUG_OUT_2(msg) +# endif // BLAM_DEBUG_OUT +#endif // BLAM_DEBUG + +// User could potentially define ComplexFloat/ComplexDouble instead of std:: +#ifndef BLAM_COMPLEX_TYPES +#define BLAM_COMPLEX_TYPES 1 +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(complex) + +namespace blam { +template +using Complex = cuda::std::complex; +using ComplexFloat = cuda::std::complex; +using ComplexDouble = cuda::std::complex; +} +#endif // BLAM_COMPLEX_TYPES + +// User could potentially define Half instead of cute:: +#ifndef BLAM_HALF_TYPE +#define BLAM_HALF_TYPE 1 +#include +namespace blam { +using Half = cute::half_t; +} +#endif // BLAM_HALF_TYPE + +namespace blam +{ +namespace cublas +{ + +inline const char* +cublas_get_error(cublasStatus_t status) +{ + switch (status) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED -- The cuBLAS library was not initialized."; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED -- Resource allocation failed inside the cuBLAS library."; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE -- An unsupported value or parameter was passed to the function."; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH -- The function requires a feature absent from the device architecture."; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR -- An access to GPU 
memory space failed."; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED -- The GPU program failed to execute."; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR -- An internal cuBLAS operation failed."; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED -- The functionality requested is not supported."; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR -- An error was detected when checking the current licensing."; + default: + return "CUBLAS_ERROR -- "; + } +} + +inline bool +cublas_is_error(cublasStatus_t status) +{ + return status != CUBLAS_STATUS_SUCCESS; +} + + +// hgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const Half* alpha, + const Half* A, int ldA, + const Half* B, int ldB, + const Half* beta, + Half* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasHgemm"); + + return cublasGemmEx(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), CUDA_R_16F, ldA, + reinterpret_cast(B), CUDA_R_16F, ldB, + reinterpret_cast(beta), + reinterpret_cast< __half*>(C), CUDA_R_16F, ldC, + CUDA_R_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// mixed hf gemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const Half* A, int ldA, + const Half* B, int ldB, + const float* beta, + float* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasGemmEx mixed half-float"); + + return cublasGemmEx(handle, transA, transB, + m, n, k, + alpha, + reinterpret_cast(A), CUDA_R_16F, ldA, + reinterpret_cast(B), CUDA_R_16F, ldB, + beta, + C, CUDA_R_32F, ldC, + CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// igemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const int32_t* alpha, + const int8_t* A, int ldA, + 
const int8_t* B, int ldB, + const int32_t* beta, + int32_t* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasIgemm"); + + return cublasGemmEx(handle, transA, transB, + m, n, k, + alpha, + A, CUDA_R_8I, ldA, + B, CUDA_R_8I, ldB, + beta, + C, CUDA_R_32I, ldC, + CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// sgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const float* A, int ldA, + const float* B, int ldB, + const float* beta, + float* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasSgemm"); + + return cublasSgemm(handle, transA, transB, + m, n, k, + alpha, + A, ldA, + B, ldB, + beta, + C, ldC); +} + +// dgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const double* alpha, + const double* A, int ldA, + const double* B, int ldB, + const double* beta, + double* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasDgemm"); + + return cublasDgemm(handle, transA, transB, + m, n, k, + alpha, + A, ldA, + B, ldB, + beta, + C, ldC); +} + +// cgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexFloat* alpha, + const ComplexFloat* A, int ldA, + const ComplexFloat* B, int ldB, + const ComplexFloat* beta, + ComplexFloat* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasCgemm"); + + return cublasCgemm(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, + reinterpret_cast(B), ldB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC); +} + +// zgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexDouble* alpha, + const ComplexDouble* A, int ldA, + const ComplexDouble* B, int ldB, + const ComplexDouble* beta, + ComplexDouble* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasZgemm"); + + return 
cublasZgemm(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, + reinterpret_cast(B), ldB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC); +} + +// hgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const Half* alpha, + const Half* A, int ldA, int loA, + const Half* B, int ldB, int loB, + const Half* beta, + Half* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasHgemmStridedBatched"); + + return cublasHgemmStridedBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, loA, + reinterpret_cast(B), ldB, loB, + reinterpret_cast(beta), + reinterpret_cast<__half*>(C), ldC, loC, + batch_size); +} + +// sgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const float* A, int ldA, int loA, + const float* B, int ldB, int loB, + const float* beta, + float* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasSgemmStridedBatched"); + + return cublasSgemmStridedBatched(handle, transA, transB, + m, n, k, + alpha, + A, ldA, loA, + B, ldB, loB, + beta, + C, ldC, loC, + batch_size); +} + +// dgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const double* alpha, + const double* A, int ldA, int loA, + const double* B, int ldB, int loB, + const double* beta, + double* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasDgemmStridedBatched"); + + return cublasDgemmStridedBatched(handle, transA, transB, + m, n, k, + alpha, + A, ldA, loA, + B, ldB, loB, + beta, + C, ldC, loC, + batch_size); +} + +// cgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexFloat* alpha, 
+ const ComplexFloat* A, int ldA, int loA, + const ComplexFloat* B, int ldB, int loB, + const ComplexFloat* beta, + ComplexFloat* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasCgemmStridedBatched"); + + return cublasCgemmStridedBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, loA, + reinterpret_cast(B), ldB, loB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC, loC, + batch_size); +} + +// zgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexDouble* alpha, + const ComplexDouble* A, int ldA, int loA, + const ComplexDouble* B, int ldB, int loB, + const ComplexDouble* beta, + ComplexDouble* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasZgemmStridedBatched"); + + return cublasZgemmStridedBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, loA, + reinterpret_cast(B), ldB, loB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC, loC, + batch_size); +} + +// hgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const Half* alpha, + const Half* const A[], int ldA, + const Half* const B[], int ldB, + const Half* beta, + Half* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasHgemmBatched"); + + return cublasHgemmBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(const_cast(A)), ldA, + // A, ldA, // cuBLAS 9.2 + reinterpret_cast(const_cast(B)), ldB, + // B, ldB, // cuBLAS 9.2 + reinterpret_cast(beta), + reinterpret_cast<__half**>(const_cast(C)), ldC, + // C, ldC, // cuBLAS 9.2 + batch_size); +} + +// sgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const float* const A[], int ldA, 
+ const float* const B[], int ldB, + const float* beta, + float* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasSgemmBatched"); + + return cublasSgemmBatched(handle, transA, transB, + m, n, k, + alpha, + const_cast(A), ldA, + // A, ldA, // cuBLAS 9.2 + const_cast(B), ldB, + // B, ldB, // cuBLAS 9.2 + beta, + const_cast(C), ldC, + // C, ldC, // cuBLAS 9.2 + batch_size); +} + +// dgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const double* alpha, + const double* const A[], int ldA, + const double* const B[], int ldB, + const double* beta, + double* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasDgemmBatched"); + + return cublasDgemmBatched(handle, transA, transB, + m, n, k, + alpha, + const_cast(A), ldA, + // A, ldA, // cuBLAS 9.2 + const_cast(B), ldB, + // B, ldB, // cuBLAS 9.2 + beta, + const_cast(C), ldC, + // C, ldC, // cuBLAS 9.2 + batch_size); +} + +// cgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexFloat* alpha, + const ComplexFloat* const A[], int ldA, + const ComplexFloat* const B[], int ldB, + const ComplexFloat* beta, + ComplexFloat* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasCgemmBatched"); + + return cublasCgemmBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + const_cast(reinterpret_cast(A)), ldA, + //reinterpret_cast(A), ldA, // cuBLAS 9.2 + const_cast(reinterpret_cast(B)), ldB, + //reinterpret_cast(B), ldB, // cuBLAS 9.2 + reinterpret_cast(beta), + const_cast(reinterpret_cast(C)), ldC, + //reinterpret_cast(C), ldC, // cuBLAS 9.2 + batch_size); +} + +// zgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexDouble* alpha, + const ComplexDouble* const A[], int ldA, + const 
ComplexDouble* const B[], int ldB, + const ComplexDouble* beta, + ComplexDouble* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasZgemmBatched"); + + return cublasZgemmBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + const_cast(reinterpret_cast(A)), ldA, + //reinterpret_cast(A), ldA, // cuBLAS 9.2 + const_cast(reinterpret_cast(B)), ldB, + //reinterpret_cast(B), ldB, // cuBLAS 9.2 + reinterpret_cast(beta), + const_cast(reinterpret_cast(C)), ldC, + //reinterpret_cast(C), ldC, // cuBLAS 9.2 + batch_size); +} + +} // end namespace cublas +} // end namespace blam diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h new file mode 100644 index 0000000000000000000000000000000000000000..88481a82e0e08f06b54c07c946d28160d41f9f07 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h @@ -0,0 +1,143 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Contains code for debugging cutlass code +*/ + +#pragma once + +#include "device_dump.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/****************************************************************************** + * Debug and logging macros + ******************************************************************************/ + +/** + * Formats and prints the given message to stdout + */ +#if !defined(CUDA_LOG) +#if !defined(__CUDA_ARCH__) +#define CUDA_LOG(format, ...) printf(format, __VA_ARGS__) +#else +#define CUDA_LOG(format, ...) 
\ + printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \ + blockIdx.x, \ + blockIdx.y, \ + blockIdx.z, \ + threadIdx.x, \ + threadIdx.y, \ + threadIdx.z, \ + __VA_ARGS__); +#endif +#endif + +/** + * Formats and prints the given message to stdout only if DEBUG is defined + */ +#if !defined(CUDA_LOG_DEBUG) +#ifdef DEBUG +#define CUDA_LOG_DEBUG(format, ...) CUDA_LOG(format, __VA_ARGS__) +#else +#define CUDA_LOG_DEBUG(format, ...) +#endif +#endif + +/** + * \brief The corresponding error message is printed to \p stderr (or \p stdout in device code) + * along with the supplied source context. + * + * \return The CUDA error. + */ +__host__ CUTLASS_DEVICE cudaError_t cuda_perror_impl(cudaError_t error, + const char* expression, + const char* filename, + int line) { + (void)filename; + (void)line; + if (error) { +#if !defined(__CUDA_ARCH__) + fprintf( + stderr, "CUDA error %d [%s, %d] in expression '%s': %s\n", error, filename, line, expression, cudaGetErrorString(error)); + fflush(stderr); +#else + printf("CUDA error %d [%s, %d] in expression '%s'\n", error, filename, line, expression); +#endif + } + return error; +} + +/** + * \brief Perror macro + */ +#ifndef CUDA_PERROR +#define CUDA_PERROR(e) cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__) +#endif + +/** + * \brief Perror macro with exit + */ +#ifndef CUDA_PERROR_EXIT +#define CUDA_PERROR_EXIT(e) \ + do { if (cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__)) { \ + exit(1); \ + } } while (0) +#endif + +/** + * \brief Perror macro only if DEBUG is defined + */ +#ifndef CUDA_PERROR_DEBUG +#ifdef DEBUG +#define CUDA_PERROR_DEBUG(e) CUDA_PERROR(e) +#else +#define CUDA_PERROR_DEBUG(e) (e) +#endif +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// A small helper class to dump a type at compile time +// Usage:: DumpType::Class +template +struct DebugType {}; + +template +void DebugTypeFunc(T const& t) { + T::t; +} + +// A small 
helper class to dump a compile time constant at compile time +// Usage: DumpValue::kConstant +template +struct DebugValue {}; diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h new file mode 100644 index 0000000000000000000000000000000000000000..a73a8cfe79dd22c2d298fcb3be8cf25d5e3f5734 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h @@ -0,0 +1,187 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include "cutlass/cutlass.h" + +/** + * \file + * \brief C++ interface to dump fragments and shared memory contents for + * debugging. + */ + +namespace cutlass { +namespace debug { + +/****************************************************************************** + * Dump the fragments + ******************************************************************************/ + +/// The first N threads dump the first M elements from their fragments with a +/// stride of S elements. If N is not specified, dump the data of all the +/// threads. If M is not specified, dump all the elements of the fragment. 
+template +CUTLASS_DEVICE void dump_fragment(Fragment const& frag, int N = 0, int M = 0, + int S = 1) { + int total_threads = blockDim.x * blockDim.y * blockDim.z; + int block_id = + blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; + int thread_id = (threadIdx.z * (blockDim.x * blockDim.y)) + + (threadIdx.y * blockDim.x) + threadIdx.x; + + if (N < 0 || N > total_threads) { + if (thread_id == 0 && block_id == 0) + printf("Thread number N = %d should between [1, %d].\n", N, + total_threads); + + __syncthreads(); + + return; + } + + int total_elements = int(frag.size()); + + if (M < 0 || M > total_elements) { + if (thread_id == 0 && block_id == 0) + printf("Element number M = %d should between [1, %d].\n", M, + total_elements); + + __syncthreads(); + + return; + } + + if (N == 0) N = total_threads; + + if (M == 0) M = total_elements; + + if (S < 1 || S > M) { + if (thread_id == 0 && block_id == 0) + printf("Stride S = %d should between [1, %d].\n", S, M); + + __syncthreads(); + + return; + } + + if (thread_id == 0 && block_id == 0) + printf("\n*******************Dumping the fragments*******************\n\n"); + + CUTLASS_PRAGMA_NO_UNROLL + for (int tid = 0; tid < N; ++tid) { + if (tid == thread_id) { + printf("TB%d W%d T%d: ", block_id, tid / 32, tid & 31); + CUTLASS_PRAGMA_NO_UNROLL + for (int i = 0; i < M; i += S) { + printf("%.0f ", float(typename Fragment::value_type(frag[i]))); + } + printf("\n"); + } + + __syncthreads(); + } + + if (thread_id == 0 && block_id == 0) + printf("\n***********************************************************\n\n"); + + __syncthreads(); + + return; +} + +/****************************************************************************** + * Dump the shared memory + ******************************************************************************/ + +#define SHMEM_ROW_SIZE 128 + +/// Dump the shared memory contents. 
ptr is the begin address, size specifies +/// the number of elements that need to be dumped, and S specifies the stride. +template +CUTLASS_DEVICE void dump_shmem(Element const* ptr, size_t size, int S = 1) { + int block_id = + blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; + int thread_id = (threadIdx.z * (blockDim.x * blockDim.y)) + + (threadIdx.y * blockDim.x) + threadIdx.x; + + if (ptr == nullptr) { + if (thread_id == 0 && block_id == 0) printf("ptr is null.\n"); + + __syncthreads(); + return; + } + + if (size < 1) { + if (thread_id == 0 && block_id == 0) + printf("Element size is less than 1\n"); + + __syncthreads(); + + return; + } + + int row_elements = SHMEM_ROW_SIZE / sizeof(Element); + + if (S < 1 || S > row_elements) { + if (thread_id == 0 && block_id == 0) + printf("Stride S = %d should between [1, %d].\n", S, row_elements); + + __syncthreads(); + + return; + } + + __syncthreads(); + + if (thread_id == 0) + printf("\n********Dumping the shared memory of TB %d*******\n\n", block_id); + + if (thread_id == 0) { + for (int i = 0; i < size; i += row_elements) { + for (int j = 0; j < row_elements; j += S) { + printf("%.0f ", float(ptr[i + j])); + } + + printf("\n"); + } + } + + if (thread_id == 0) + printf("\n***********************************************************\n\n"); + + __syncthreads(); + + return; +} +} // namespace debug +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h new file mode 100644 index 0000000000000000000000000000000000000000..59457b2e8122f46e443844fe276b2c7fb35f3f56 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h @@ -0,0 +1,402 @@ 
+/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to do group norm on a device memory tensor with NHWC layout. 
The tensor will be divided into [N, H, W, G, C'] and then we do normalization on [H, W, C']. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" +#include "device_utils.h" +#include + +namespace cutlass { + +/** \brief interface to do group norm on a device memory tensor with NHWC layout. + * \tparam T: data type + */ +template +void groupnorm(cutlass::Tensor4DCoord input_size, + const int num_groups, + const float eps, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_gamma, + TensorRef ref_beta, + cudaStream_t stream); + +extern __shared__ char groupnorm_shm[]; + +// For small prod_dim1_to_last_dim/num_groups, to avoid multiple loads from global memory, +// we store the input in the shared memory. +// grid(num_groups, dim0) +// block(BLOCKSIZE) +// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group +template +__global__ void groupnorm_twopass_store_locally(T* output, + const T* input, + const T* gamma, + const T* beta, + int num_groups, + int prod_dim1_to_last_dim, + int last_dim, + const float eps, + const int TVecs_PER_THREAD) +{ + const int bid = blockIdx.y; // index of batch + const int gid = blockIdx.x; // index of group + const int tid = threadIdx.x; // index of thread + const int bdimx = blockDim.x; + const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; + const int v_reduce_elements = s_reduce_elements / T_PER_TVec; + const int s_group_stride = last_dim / num_groups; + const int v_group_stride = s_group_stride / T_PER_TVec; + const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec; + const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group; + TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group; + T* local_val = ((T*)groupnorm_shm) + TVecs_PER_THREAD * T_PER_TVec * tid; + float local_sum[1] = {0.0f}; + +// load from global memory into shared 
memory +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + TVec tmp_vec = input_TVec_ptr[offset_in_group]; + T* tmp_vec_ptr = (T*)(&tmp_vec); + const int local_val_offset = i * T_PER_TVec; +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = static_cast(tmp_vec_ptr[j]); + local_sum[0] += tmp; + local_val[local_val_offset + j] = tmp_vec_ptr[j]; + } + } + } + __shared__ float s_mean, s_variance; + + // reduction for mean + if (bdimx <= 32) { + warpReduceSum(local_sum); + } + else { + blockReduceSum(local_sum); + } + if (tid == 0) { + s_mean = local_sum[0] / s_reduce_elements; + } + __syncthreads(); + + // reduction for std + local_sum[0] = 0.0f; +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + const int local_val_offset = i * T_PER_TVec; +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = static_cast(local_val[local_val_offset + j]); + tmp -= s_mean; + local_sum[0] += tmp * tmp; + } + } + } + if (bdimx <= 32) { + warpReduceSum(local_sum); + } + else { + blockReduceSum(local_sum); + } + if (tid == 0) { + s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps); + } + __syncthreads(); + + // normalize + const int gamma_offset_of_group = gid * v_group_stride; + const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group; + const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group; +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * 
last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec; + const int local_val_offset = i * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group]; + TVec beta_val = beta_TVec_ptr[gamma_offset_in_group]; + T* gamma_val_ptr = (T*)(&gamma_val); + T* beta_val_ptr = (T*)(&beta_val); + TVec tmp_vec; + T* tmp_vec_ptr = (T*)(&tmp_vec); +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = (static_cast(local_val[local_val_offset + j]) - s_mean) * s_variance + * static_cast(gamma_val_ptr[j]) + + static_cast(beta_val_ptr[j]); + if (sizeof(T) == sizeof(half)) { + tmp_vec_ptr[j] = T(__float2half_rn(tmp)); + } + else { + tmp_vec_ptr[j] = T(tmp); + } + } + output_TVec_ptr[offset_in_group] = tmp_vec; + } + } +} + +// For large prod_dim1_to_last_dim/num_groups, +// in which the data cannot be stored locally, +// we will load from global memory multiple times, +// grid(num_groups, dim0) +// block(BLOCKSIZE) +// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group +template +__global__ void groupnorm_twopass_multiple_load(T* output, + const T* input, + const T* gamma, + const T* beta, + int num_groups, + int prod_dim1_to_last_dim, + int last_dim, + const float eps, + const int TVecs_PER_THREAD) +{ + const int bid = blockIdx.y; // index of batch + const int gid = blockIdx.x; // index of group + const int tid = threadIdx.x; // index of thread + const int bdimx = blockDim.x; + const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; + const int v_reduce_elements = s_reduce_elements / T_PER_TVec; + const int s_group_stride = last_dim / num_groups; + const int v_group_stride = s_group_stride / T_PER_TVec; + const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec; + const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group; + TVec* 
output_TVec_ptr = (TVec*)(output) + offset_of_group; + float local_sum[1] = {0.0f}; + +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + TVec tmp_vec = input_TVec_ptr[offset_in_group]; + T* tmp_vec_ptr = (T*)(&tmp_vec); +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = static_cast(tmp_vec_ptr[j]); + local_sum[0] += tmp; + } + } + } + __shared__ float s_mean, s_variance; + + // reduction for mean + if (bdimx <= 32) { + warpReduceSum(local_sum); + } + else { + blockReduceSum(local_sum); + } + if (tid == 0) { + s_mean = local_sum[0] / s_reduce_elements; + } + __syncthreads(); + + // reduction for std + local_sum[0] = 0.0f; +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + TVec tmp_vec = input_TVec_ptr[offset_in_group]; + T* tmp_vec_ptr = (T*)(&tmp_vec); +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = static_cast(tmp_vec_ptr[j]); + tmp -= s_mean; + local_sum[0] += tmp * tmp; + } + } + } + if (bdimx <= 32) { + warpReduceSum(local_sum); + } + else { + blockReduceSum(local_sum); + } + if (tid == 0) { + s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps); + } + __syncthreads(); + + // normalize + const int gamma_offset_of_group = gid * v_group_stride; + const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group; + const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group; +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { 
+ const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec; + TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group]; + TVec beta_val = beta_TVec_ptr[gamma_offset_in_group]; + T* gamma_val_ptr = (T*)(&gamma_val); + T* beta_val_ptr = (T*)(&beta_val); + TVec tmp_vec = input_TVec_ptr[offset_in_group]; + T* tmp_vec_ptr = (T*)(&tmp_vec); + TVec output_tmp_vec; + T* output_tmp_vec_ptr = (T*)(&output_tmp_vec); +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = + (static_cast(tmp_vec_ptr[j]) - s_mean) * s_variance * static_cast(gamma_val_ptr[j]) + + static_cast(beta_val_ptr[j]); + if (sizeof(T) == sizeof(half)) { + output_tmp_vec_ptr[j] = T(__float2half_rn(tmp)); + } + else { + output_tmp_vec_ptr[j] = T(tmp); + } + } + output_TVec_ptr[offset_in_group] = output_tmp_vec; + } + } +} + +//ref_input & ref_output should be [N, H, W, C] +//ref_gamma & ref_beta should be [1, 1, 1, C] +template +void groupnorm(cutlass::Tensor4DCoord input_size, + const int num_groups, + const float eps, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_gamma, + TensorRef ref_beta, + cudaStream_t stream){ + const int N = input_size.n(); + const int H = input_size.h(); + const int W = input_size.w(); + const int C = input_size.c(); + if (C % num_groups != 0){ + printf("[ERROR] C should be a multiple of num_groups.\n"); + } + T* output = ref_output.data(); + const T* input = ref_input.data(); + const T* gamma = ref_gamma.data(); + const T* beta = ref_beta.data(); + + const int dim0 = N; + const int last_dim = C; + const int prod_dim1_to_last_dim = H*W*C; + const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; + const int s_group_stride = last_dim / num_groups; + 
dim3 grid(num_groups, dim0); + int threadblock_size = 32; + if (s_group_stride % 2 == 0) { + const int T_PER_TVec = 2; + while (threadblock_size < 1024) { + if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8) + break; + threadblock_size *= 2; + } + dim3 block(threadblock_size); + const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size; + const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T); + // for small s_reduce_elements, specific case for H=W=22, C=1280, num_groups=32; + // the size of grid & block may have better choice for different cases. + // ensure shared memory is smaller than 48KB + if (std::is_same::value){ + if (shm_size < 48 * 1024) { + groupnorm_twopass_store_locally<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + else { + groupnorm_twopass_multiple_load<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + } + else{ + if (shm_size < 48 * 1024) { + groupnorm_twopass_store_locally<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + else { + groupnorm_twopass_multiple_load<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + } + } + else { + const int T_PER_TVec = 1; + while (threadblock_size < 1024) { + if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8) + break; + threadblock_size *= 2; + } + dim3 block(threadblock_size); + const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size; + const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T); + if (shm_size < 48 * 1024) { + groupnorm_twopass_store_locally<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + else { + groupnorm_twopass_multiple_load<<>>( + output, 
input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + } + +} + +} //namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h new file mode 100644 index 0000000000000000000000000000000000000000..0fcbf5cb0f4bf3152a708c6e3845e89fd214cfac --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h @@ -0,0 +1,644 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to do layernorm on a device memory tensor with RowMajor layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" +#include "device_utils.h" +#include + +namespace cutlass { + +/** \brief interface to do layernorm on a device memory tensor with RowMajor layout. 
+ * \tparam T: data type + */ +template +void layernorm(cutlass::MatrixCoord tensor_size, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_gamma, + TensorRef ref_beta, + cudaStream_t stream); + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements + * + * Each block handles the row selected by blockIdx.x. Every thread caches its ITEM_PER_THREAD + * elements in local_val during the first pass, so `input` is read only once; the later passes + * reuse the cached copies. s_mean and s_variance are written by thread 0 and published to the + * rest of the block by the __syncthreads() that follows each write. +*/ +template +__global__ void layernorm_twoPassAlgo_stored_locally_e1(T* output, + const T* input, + const T* gamma, + const T* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + T local_val[ITEM_PER_THREAD]; + float local_sums[1] = {0.0f}; + int offset = m_idx * n; + input += offset; + output += offset; + /* Pass 1: cache this thread's elements (zero for indices at or past n) and accumulate their sum. */ + const T zero = T(0.0f); + #pragma unroll + for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ + int index = tid + i*bdimx; + local_val[i] = index < n ? input[index] : zero; + local_sums[0] += static_cast(local_val[i]); + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + /* Pass 2: accumulate squared deviations from s_mean over the cached, in-range elements. */ + local_sums[0] = 0.0f; + #pragma unroll + for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ + int index = tid + i*bdimx; + if (index < n){ + const float tmp = static_cast(local_val[i]) - s_mean; + local_sums[0] += tmp * tmp; + } + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); /* reciprocal standard deviation (epsilon 1e-5), despite the name */ + } + __syncthreads(); + /* Pass 3: normalize each cached element, scale by gamma, shift by beta, and store. */ + #pragma unroll + for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ + int index = tid + i*bdimx; + if (index < n) { + const T gamma_val = gamma[index]; + const T beta_val = beta[index]; + output[index] = T((static_cast(local_val[i]) - s_mean) * s_variance * static_cast(gamma_val) + static_cast(beta_val)); + } + } +} + 
+/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements; + * + * Two-element vectorized variant: each block handles the row selected by blockIdx.x and indexes it + * in units of T2 (n / 2 vectors per row). Each thread caches its vectors in local_val during the + * first pass and reuses them for the squared-deviation pass and the final normalization. + * The ITEM_PER_THREAD loops use lowercase `#pragma unroll`, the spelling the CUDA toolchain + * documents and the one layernorm_twoPassAlgo_stored_locally_e1 already uses. +*/ +template +__global__ void layernorm_twoPassAlgo_stored_locally_e2(T2* output, + const T2* input, + const T2* gamma, + const T2* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + T2 local_val[ITEM_PER_THREAD]; + const int n_2 = n / 2; + int offset = m_idx * n_2; + input += offset; + output += offset; + /* Pass 1: cache this thread's vectors (zero for indices at or past n_2) and accumulate their sum. */ + const T2 zero = {T(0.0f), T(0.0f)}; + #pragma unroll + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + local_val[i] = index < n_2 ? input[index] : zero; + local_sums[0] += static_cast(local_val[i].x) + static_cast(local_val[i].y); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + /* Pass 2: accumulate squared deviations from s_mean over the cached vectors. */ + local_sums[0] = 0.0f; + #pragma unroll + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_2){ + const float2 tmp = {static_cast(local_val[i].x) - s_mean, + static_cast(local_val[i].y) - s_mean}; + local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y; + } + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); /* reciprocal standard deviation (epsilon 1e-5), despite the name */ + } + __syncthreads(); + /* Pass 3: normalize, scale by gamma, shift by beta, and store one vector per iteration. */ + #pragma unroll + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_2){ + const T2 gamma_val = gamma[index]; + const T2 beta_val = beta[index]; + T2 tmp; + tmp.x = T((static_cast(local_val[i].x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); + tmp.y = T((static_cast(local_val[i].y) - 
s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); + output[index] = tmp; + } + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*4 elements; + * + * Four-element vectorized variant: each block handles the row selected by blockIdx.x and indexes it + * in units of T4 (n / 4 vectors per row). Each thread caches its vectors in local_val during the + * first pass and reuses them for the squared-deviation pass and the final normalization. + * The ITEM_PER_THREAD loops use lowercase `#pragma unroll`, the spelling the CUDA toolchain + * documents and the one layernorm_twoPassAlgo_stored_locally_e1 already uses. +*/ +template +__global__ void layernorm_twoPassAlgo_stored_locally_e4(T4* output, + const T4* input, + const T4* gamma, + const T4* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + T4 local_val[ITEM_PER_THREAD]; + const int n_4 = n / 4; + int offset = m_idx * n_4; + input += offset; + output += offset; + /* Pass 1: cache this thread's vectors (zero for indices at or past n_4) and accumulate their sum. */ + const T4 zero = {T(0.0f), T(0.0f), T(0.0f), T(0.0f)}; + #pragma unroll + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + local_val[i] = index < n_4 ? input[index] : zero; + local_sums[0] += static_cast(local_val[i].x) + static_cast(local_val[i].y) + + static_cast(local_val[i].z) + static_cast(local_val[i].w); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + /* Pass 2: accumulate squared deviations from s_mean over the cached vectors. */ + local_sums[0] = 0.0f; + #pragma unroll + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_4){ + const float4 tmp = {static_cast(local_val[i].x) - s_mean, + static_cast(local_val[i].y) - s_mean, + static_cast(local_val[i].z) - s_mean, + static_cast(local_val[i].w) - s_mean}; + local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y + tmp.z * tmp.z + tmp.w * tmp.w; + } + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); /* reciprocal standard deviation (epsilon 1e-5), despite the name */ + } + __syncthreads(); + /* Pass 3: normalize, scale by gamma, shift by beta, and store one vector per iteration. */ + #pragma unroll + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = 
i*bdimx + tid; + if (index < n_4){ + const T4 gamma_val = gamma[index]; + const T4 beta_val = beta[index]; + T4 tmp; + tmp.x = T((static_cast(local_val[i].x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); + tmp.y = T((static_cast(local_val[i].y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); + tmp.z = T((static_cast(local_val[i].z) - s_mean)*s_variance*static_cast(gamma_val.z) + static_cast(beta_val.z)); + tmp.w = T((static_cast(local_val[i].w) - s_mean)*s_variance*static_cast(gamma_val.w) + static_cast(beta_val.w)); + output[index] = tmp; + } + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements + * + * Variant that does not cache the row: each of the three passes re-reads `input`, with every + * thread stepping through the row in increments of blockDim.x. +*/ +template +__global__ void layernorm_twoPassAlgo_e1(T* output, + const T* input, + const T* gamma, + const T* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + int offset = m_idx * n; + input += offset; + output += offset; + + for (int index = tid ; index < n ; index += bdimx){ + float local_val = static_cast(input[index]); + local_sums[0] += local_val; + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + for (int index = tid ; index < n ; index += bdimx){ + float local_val = static_cast(input[index]); + local_val = local_val - s_mean; + local_sums[0] += local_val * local_val; + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + for (int index = tid ; index < n ; index += bdimx){ + const T 
gamma_val = gamma[index]; + const T beta_val = beta[index]; + const T local_val = input[index]; + output[index] = T((static_cast(local_val) - s_mean) * s_variance * static_cast(gamma_val) + static_cast(beta_val)); + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements; +*/ +template +__global__ void layernorm_twoPassAlgo_e2(T2* output, + const T2* input, + const T2* gamma, + const T2* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + const int n_2 = n / 2; + int offset = m_idx * n_2; + input += offset; + output += offset; + + for (int index = tid; index < n_2; index += bdimx) { + const T2 local_val = input[index]; + local_sums[0] += static_cast(local_val.x) + static_cast(local_val.y); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + for (int index = tid; index < n_2; index += bdimx) { + const T2 local_val = input[index]; + const float2 tmp = {static_cast(local_val.x) - s_mean, + static_cast(local_val.y) - s_mean}; + local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y; + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + for (int index = tid; index < n_2; index += bdimx) { + const T2 local_val = input[index]; + const T2 gamma_val = gamma[index]; + const T2 beta_val = beta[index]; + T2 tmp; + tmp.x = T((static_cast(local_val.x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); + tmp.y = T((static_cast(local_val.y) - 
s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); + output[index] = tmp; + } +} + +/** + * Host-side launcher for the layernorm kernels in this header. + * + * Reads m (rows) and n (columns) from tensor_size, launches one block per row (grid.x = m) on + * `stream`, and picks the kernel and block size from n: + * - n divisible by 4 with 128 <= n <= 4096: four-element vectorized, locally cached kernel; + * - otherwise even n: two-element vectorized kernels, cached up to n = 32768, uncached above; + * - odd n: scalar kernels, cached up to n = 32768, uncached above. + * In every cached branch block.x is the per-row thread count rounded up to a whole warp of 32. + * The float or half flavour of each vectorized kernel is chosen by the std::is_same test on T. + * No launch status is checked here; callers that need it must query the CUDA runtime themselves. + */ +template +void layernorm(cutlass::MatrixCoord tensor_size, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_gamma, + TensorRef ref_beta, + cudaStream_t stream){ + const int m = tensor_size.row(); + const int n = tensor_size.column(); + T* output = ref_output.data(); + const T* input = ref_input.data(); + const T* gamma = ref_gamma.data(); + const T* beta = ref_beta.data(); + dim3 grid(m); + dim3 block((n + 31)/32*32); + if (block.x > 1024){ + block.x = 1024; + } + // TODO : There should be better configs for different cases, we only use several samples to show how to use here + // TODO : using registers to store values locally can reduce the loads from global memory and speedup the kernels. + if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) { + block.x = (n/4 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e4<<>>( + (float4*)output, + (const float4*)input, + (const float4*)gamma, + (const float4*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e4<<>>( + (half4*)output, + (const half4*)input, + (const half4*)gamma, + (const half4*)beta, + m, + n); + } + } //if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) + else if (n % 2 == 0) { + if (n / 2 <= 1024) { + block.x = (n/2 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } //if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n / 2 <= 1024) + else if (n <= 8192) { + block.x = ((n + 7)/8 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const 
float2*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n <= 8192) + else if (n <= 16384) { + block.x = ((n + 15)/ 16 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n <= 16384) + else if (n <= 32768) { + block.x = ((n + 31)/32 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n <= 32768) + else { + if (block.x > 512) + block.x = 512; + if (std::is_same::value) { + layernorm_twoPassAlgo_e2<<>>( + (float2 *)output, + (const float2 *)input, + (const float2 *)gamma, + (const float2 *)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_e2<<>>( + (half2 *)output, + (const half2 *)input, + (const half2 *)gamma, + (const half2 *)beta, + m, + n); + } + } + } // if (n % 2 == 0) + else { + if (n <= 1024) { + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 1024) + else if (n <= 8192) { + block.x = ((n + 7)/8 + 31)/32*32; + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 8192) + else if (n <= 16384) { + /* Round ceil(n/16) up to a whole warp using an addend of 31, as every sibling branch does. The previous addend of 32 added a spare warp whenever ceil(n/16) was already a multiple of 32, and for odd n in 16369..16383 it asked for 1056 threads, above the 1024-thread block limit, so the launch failed. */ + block.x = ((n + 15)/16 + 31)/32*32; + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, 
+ input, + gamma, + beta, + m, + n); + } // if (n <= 16384) + else if (n <= 32768) { + block.x = ((n + 31)/32 + 31)/32*32; + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 32768) + else{ + if (block.x > 512) { + block.x = 512; + } + layernorm_twoPassAlgo_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } + } +} + +} //namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h new file mode 100644 index 0000000000000000000000000000000000000000..44f6a467a5d0938289e4bc127cddc13b9aeabdf3 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h @@ -0,0 +1,375 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief C++ interface to CUDA device memory management functions. + */ + +#include +#include + +#include "cutlass/platform/platform.h" +#include "cutlass/numeric_types.h" +#include "cutlass/trace.h" +#include "exceptions.h" + +namespace cutlass { +namespace device_memory { + +/****************************************************************************** + * Allocation lifetime + ******************************************************************************/ + +/// Allocate a buffer of \p count elements of type \p T on the current CUDA device +template +T* allocate(size_t count = 1) { + + T* ptr = 0; + size_t bytes = count * sizeof_bits::value / 8; + + cudaError_t cuda_error = cudaMalloc((void**)&ptr, bytes); + + if (cuda_error != cudaSuccess) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 0) + std::ostringstream os; + os << "cutlass::device_memory::allocate: cudaMalloc failed: bytes=" << bytes; + CUTLASS_TRACE_HOST(os.str()); +#endif + throw cuda_exception("Failed to allocate memory", cuda_error); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + std::ostringstream os; 
+ os << "cutlass::device_memory::allocate: Successful cudaMalloc: bytes=" << bytes; + CUTLASS_TRACE_HOST(os.str()); + } +#endif + + return ptr; +} + +/// Free the buffer pointed to by \p ptr +template +void free(T* ptr) { + if (ptr) { + cudaError_t cuda_error = (cudaFree(ptr)); + if (cuda_error != cudaSuccess) { + throw cuda_exception("Failed to free device memory", cuda_error); + } + } +} + +/****************************************************************************** + * Data movement + ******************************************************************************/ + +template +void copy(T* dst, T const* src, size_t count, cudaMemcpyKind kind) { + size_t bytes = count * sizeof_bits::value / 8; + if (bytes == 0 && count > 0) { + bytes = 1; + } + cudaError_t cuda_error = (cudaMemcpy(dst, src, bytes, kind)); + if (cuda_error != cudaSuccess) { + std::ostringstream os; + os << "cutlass::device_memory::copy: cudaMemcpy() failed: " + << "dst=" << dst << ", src=" << src + << ", bytes=" << bytes << ", count=" << count; + if (kind == cudaMemcpyHostToDevice) { + os << ", kind=cudaMemcpyHostToDevice"; + } + else if (kind == cudaMemcpyDeviceToHost) { + os << ", kind=cudaMemcpyDeviceToHost"; + } + else if (kind == cudaMemcpyDeviceToDevice) { + os << ", kind=cudaMemcpyDeviceToDevice"; + } + else if (kind == cudaMemcpyHostToHost) { + os << ", kind=cudaMemcpyHostToHost"; + } + else if (kind == cudaMemcpyDefault) { + os << ", kind=cudaMemcpyDefault"; + } + else { + os << ", kind=Unknown"; + } + os << ", error: " << cudaGetErrorString(cuda_error); + + throw cuda_exception(os.str().c_str(), cuda_error); + } +} + +template +void copy_to_device(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyHostToDevice); +} + +template +void copy_to_host(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyDeviceToHost); +} + +template +void copy_device_to_device(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, 
cudaMemcpyDeviceToDevice); +} + +template +void copy_host_to_host(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyHostToHost); +} + +/// Copies elements from device memory to host-side range +template +void insert_to_host(OutputIterator begin, OutputIterator end, T const* device_begin) { + size_t elements = end - begin; + copy_to_host(&*begin, device_begin, elements); +} + +/// Copies elements to device memory from host-side range +template +void insert_to_device(T* device_begin, InputIterator begin, InputIterator end) { + size_t elements = end - begin; + copy_to_device(device_begin, &*begin, elements); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device_memory + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class DeviceAllocation { +public: + + /// Delete functor for CUDA device memory + struct deleter { + void operator()(T* ptr) { + cudaError_t cuda_error = (cudaFree(ptr)); + if (cuda_error != cudaSuccess) { + // noexcept + // throw cuda_exception("cudaFree() failed", cuda_error); + return; + } + } + }; + +public: + // + // Data members + // + + /// Number of elements of T allocated on the current CUDA device + size_t capacity; + + /// Smart pointer + platform::unique_ptr smart_ptr; + +public: + + // + // Static methods + // + + /// Static member to compute the number of bytes needed for a given number of elements + static size_t bytes(size_t elements) { + if (sizeof_bits::value < 8) { + size_t const kElementsPerByte = 8 / sizeof_bits::value; + return elements / kElementsPerByte; + } + else { + size_t const kBytesPerElement = sizeof_bits::value / 8; + return elements * kBytesPerElement; + } + } + +public: + + // + // Methods + // + + /// Constructor: allocates no memory + DeviceAllocation() : capacity(0) {} + + /// Constructor: allocates \p capacity elements on the current CUDA device + 
DeviceAllocation(size_t _capacity) : + smart_ptr(device_memory::allocate(_capacity)), capacity(_capacity) {} + + /// Constructor: allocates \p capacity elements on the current CUDA device taking ownership of the allocation + DeviceAllocation(T *ptr, size_t _capacity) : smart_ptr(ptr), capacity(_capacity) {} + + /// Copy constructor + DeviceAllocation(DeviceAllocation const &p): + smart_ptr(device_memory::allocate(p.capacity)), capacity(p.capacity) { + + device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity); + } + + /// Move constructor + DeviceAllocation(DeviceAllocation &&p): capacity(0) { + std::swap(smart_ptr, p.smart_ptr); + std::swap(capacity, p.capacity); + } + + /// Destructor + ~DeviceAllocation() { reset(); } + + /// Returns a pointer to the managed object + T* get() const { return smart_ptr.get(); } + + /// Releases the ownership of the managed object (without deleting) and resets capacity to zero + T* release() { + capacity = 0; + return smart_ptr.release(); + } + + /// Deletes the managed object and resets capacity to zero + void reset() { + capacity = 0; + smart_ptr.reset(); + } + + /// Deletes managed object, if owned, and allocates a new object + void reset(size_t _capacity) { + reset(device_memory::allocate(_capacity), _capacity); + } + + /// Deletes managed object, if owned, and replaces its reference with a given pointer and capacity + void reset(T* _ptr, size_t _capacity) { + smart_ptr.reset(_ptr); + capacity = _capacity; + } + + /// Allocates a new buffer and copies the old buffer into it. The old buffer is then released. 
+ void reallocate(size_t new_capacity) { + + platform::unique_ptr new_allocation(device_memory::allocate(new_capacity)); + + device_memory::copy_device_to_device( + new_allocation.get(), + smart_ptr.get(), + std::min(new_capacity, capacity)); + + std::swap(smart_ptr, new_allocation); + std::swap(new_capacity, capacity); + } + + /// Returns the number of elements + size_t size() const { + return capacity; + } + + /// Returns the number of bytes needed to store the allocation + size_t bytes() const { + return bytes(capacity); + } + + /// Returns a pointer to the object owned by *this + T* operator->() const { return smart_ptr.get(); } + + /// Returns the deleter object which would be used for destruction of the managed object. + deleter& get_deleter() { return smart_ptr.get_deleter(); } + + /// Returns the deleter object which would be used for destruction of the managed object (const) + const deleter& get_deleter() const { return smart_ptr.get_deleter(); } + + /// Copies a device-side memory allocation + DeviceAllocation & operator=(DeviceAllocation const &p) { + if (capacity != p.capacity) { + smart_ptr.reset(device_memory::allocate(p.capacity)); + capacity = p.capacity; + } + device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity); + return *this; + } + + /// Move assignment + DeviceAllocation & operator=(DeviceAllocation && p) { + std::swap(smart_ptr, p.smart_ptr); + std::swap(capacity, p.capacity); + return *this; + } + + /// Copies the entire allocation from another location in device memory. 
+ void copy_from_device(T const *ptr) const { + copy_from_device(ptr, capacity); + } + + /// Copies a given number of elements from device memory + void copy_from_device(T const *ptr, size_t elements) const { + device_memory::copy_device_to_device(get(), ptr, elements); + } + + void copy_to_device(T *ptr) const { + copy_to_device(ptr, capacity); + } + + void copy_to_device(T *ptr, size_t elements) const { + device_memory::copy_device_to_device(ptr, get(), elements); + } + + void copy_from_host(T const *ptr) const { + copy_from_host(ptr, capacity); + } + + void copy_from_host(T const *ptr, size_t elements) const { + device_memory::copy_to_device(get(), ptr, elements); + } + + void copy_to_host(T *ptr) const { + copy_to_host(ptr, capacity); + } + + void copy_to_host(T *ptr, size_t elements) const { + device_memory::copy_to_host(ptr, get(), elements); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace device_memory { + +/// Device allocation abstraction that tracks size and capacity +template +using allocation = cutlass::DeviceAllocation; + +} // namespace device_memory + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h new file mode 100644 index 0000000000000000000000000000000000000000..8e38029951d27c0be8da059b59d2a83fe2762ef1 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h @@ -0,0 +1,141 @@ +/****************************************************************************** + * 
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to transform a device memory tensor from NCHW layout to NHWC layout. 
+ */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" + +namespace cutlass { + +/** \brief interface to transform a device memory tensor from NCHW layout to NHWC layout. + * \tparam T: data type + */ +template +void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream); + +template +__global__ void nchw_to_nhwc_kernel(T *output, + const T *input, + const int n, + const int h, + const int w, + const int c) { + const int hw = h*w; + const int chw = c*hw; + __shared__ T shbuf[32 * (32 + 1)]; + const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x; + const int32_t wid = tid / 32; + const int32_t lid = tid % 32; + const int32_t ni = blockIdx.z; + const int32_t ci0 = blockIdx.y * 32; + const int32_t hwi0 = blockIdx.x * 32; + + const size_t input_idx = ni * chw + (ci0 + wid) * hw + hwi0; + const T *A = input + input_idx; + if (hwi0 + lid < hw) { + const int lid_x_33 = lid * 33; + if ((ci0 + 32) <= c) { + int ci = wid; // between 0 and 7 + CUTLASS_PRAGMA_UNROLL + for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) { + shbuf[lid_x_33 + ci] = A[lid]; + A = &A[8 * hw]; + ci += 8; + } + } else { + for (int ci = wid; ci < 32; ci += 8) { + if ((ci + ci0) < c) { + shbuf[lid_x_33 + ci] = A[lid]; + } + A = &A[8 * hw]; + } + } + } + __syncthreads(); + + const int32_t ciOut = ci0 + lid; + output = &output[ni * chw + ciOut]; + if (ciOut < c) { + if (hwi0 + 32 < hw) { + int hwI = wid; + CUTLASS_PRAGMA_UNROLL + for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) { + output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid]; + hwI += 8; + } + } else { + for (int hwI = wid; hwI < 32; hwI += 8) { + if (hwi0 + hwI < hw) { + output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid]; + } + } + } + } +} + +template +void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size, + 
cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream) { + + assert( + input_tensor_size.n() == output_tensor_size.n() && + input_tensor_size.c() == output_tensor_size.h() && + input_tensor_size.h() == output_tensor_size.w() && + input_tensor_size.w() == output_tensor_size.c()); + + int n = output_tensor_size.n(); + int h = output_tensor_size.h(); + int w = output_tensor_size.w(); + int c = output_tensor_size.c(); + + dim3 grid((h*w + 31)/32, (c + 31)/32, n); + dim3 block(32, 8); + nchw_to_nhwc_kernel<<>>(ref_output.data(), ref_input.data(), + n, h, w, c); +} + +} //namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h new file mode 100644 index 0000000000000000000000000000000000000000..f58da62a35350b4a865f4521ec1cbb76ae87e874 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h @@ -0,0 +1,276 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels for padding in device memory with NHWC layout. 
+ */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" + +namespace cutlass { + +/** \brief interface for padding in a device memory tensor with NHWC layout + * \tparam T: data type + */ +template +void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream); + + +template +__global__ void nhwc_padding_kernel(const int32_t n, + const int32_t h, + const int32_t w, + const int32_t c_in, + const int32_t c_out, + const T zero, + const T *input, + T *output){ + + const int32_t idx_jump = blockDim.x * gridDim.x; + const int32_t total_elements = n * h * w * c_out; + + int32_t c_idx, w_idx, h_idx, n_idx, resudial; + + T value; + for (int32_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total_elements; idx += idx_jump) { + + c_idx = idx%c_out; + if (c_idx >= c_in){ + value = zero; + } + else{ + resudial = idx/c_out; + w_idx = resudial%w; + resudial = resudial/w; + h_idx = resudial%h; + n_idx = resudial/h; + resudial = ((n_idx * h + h_idx) * w + w_idx) * c_in + c_idx; + value = input[resudial]; + } + output[idx] = value; + } +} + + +// fast kernel for c_in = 3 & c_out = 4 +template +__global__ void nhwc_padding_channel_3To4_kernel(const int32_t n, + const int32_t h, + const int32_t w, + const Tio *input, + Tio *output, + const int32_t max_output_element, + const int32_t max_input_element, + const Tio zero_io, + const Telement zero_element){ + __shared__ Tio shm[192]; + const int tidx = blockIdx.x * 192 + threadIdx.x; + const int threadidx = threadIdx.x; + + shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx]; + __syncthreads(); + + const int output_offset = blockIdx.x * 256; + const int lower_bound = max_output_element < output_offset + 256 ? 
max_output_element : output_offset + 256; + for (int i = output_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192) + { + const Telement* shm_element = (const Telement*)shm + j*3*element_in_Tio/4; + Telement array[element_in_Tio]; + CUTLASS_PRAGMA_UNROLL + for (int k = 0 ; k < element_in_Tio ; k++) + array[k] = ((k+1)%4 == 0) ? zero_element : shm_element[(k > 3) ? (k - 1) : k]; + output[i] = *((const Tio *)array); + } +} + +// fast kernel for c_in = 3 & c_out = 8 +template +__global__ void nhwc_padding_channel_3To8_kernel(const int32_t n, + const int32_t h, + const int32_t w, + const Tio *input, + Tio *output, + const int32_t max_output_element, + const int32_t max_input_element, + const Tio zero_io, + const Telement zero_element){ + __shared__ Tio shm[192]; + const int tidx = blockIdx.x * 192 + threadIdx.x; + const int threadidx = threadIdx.x; + + shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx]; + __syncthreads(); + + const int output_offset = blockIdx.x * 512; + const int lower_bound = max_output_element < output_offset + 512 ? max_output_element : output_offset + 512; + for (int i = output_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192) + { + const Telement* shm_element = (const Telement*)shm + (element_in_Tio == 4 ? j/2 : j)*3; + Telement array[element_in_Tio]; + //float + if (element_in_Tio == 4){ + CUTLASS_PRAGMA_UNROLL + for (int k = 0 ; k < element_in_Tio ; k++) + array[k] = ((j % 2) == 1) ? zero_element : ((k >= 3) ? zero_element : shm_element[k]); + } + //half + else{ + CUTLASS_PRAGMA_UNROLL + for (int k = 0 ; k < element_in_Tio ; k++) + array[k] = (k >= 3) ? 
zero_element : shm_element[k]; + } + output[i] = *((const Tio *)array); + } +} + +template +void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream){ + assert( + input_tensor_size.n() == output_tensor_size.n() && + input_tensor_size.h() == output_tensor_size.h() && + input_tensor_size.w() == output_tensor_size.w() && + input_tensor_size.c() <= output_tensor_size.c()); + + int n = input_tensor_size.n(); + int h = input_tensor_size.h(); + int w = input_tensor_size.w(); + int c_in = input_tensor_size.c(); + int c_out = output_tensor_size.c(); + + //case 1 : channel == 3 padding to 4 or 8 + if ((c_out == 4 || c_out == 8) && c_in == 3 && (n*h*w % 8 == 0)){ + dim3 block(192); + const int nhw = n*h*w; + const int nhwc = nhw*c_in; + //for half_t + if (cutlass::sizeof_bits::value == 16){ + const int element_in_Tio = 8; + const int max_input_element = nhwc/element_in_Tio; + const int max_output_element = nhw*c_out/element_in_Tio; + const int4 zero_io = {0, 0, 0, 0}; + const half_t zero_element = static_cast(0.0f); + dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio)); + if (c_out == 4){ + nhwc_padding_channel_3To4_kernel<<>> + (n, h, w, + (const int4 *)ref_input.data(), + (int4 *)ref_output.data(), + max_output_element, + max_input_element, + zero_io, + zero_element); + } + else if (c_out == 8){ + nhwc_padding_channel_3To8_kernel<<>> + (n, h, w, + (const int4 *)ref_input.data(), + (int4 *)ref_output.data(), + max_output_element, + max_input_element, + zero_io, + zero_element); + } + } + //for float + else{ + const int element_in_Tio = 4; + const int max_input_element = nhwc/element_in_Tio; + const int max_output_element = nhw*c_out/element_in_Tio; + const float4 zero_io = {0.0f, 0.0f, 0.0f, 0.0f}; + const float zero_element = 0.0f; + dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio)); + if (c_out == 4){ + 
nhwc_padding_channel_3To4_kernel<<>> + (n, h, w, + (const float4 *)ref_input.data(), + (float4 *)ref_output.data(), + max_output_element, + max_input_element, + zero_io, + zero_element); + } + else if (c_out == 8){ + nhwc_padding_channel_3To8_kernel<<>> + (n, h, w, + (const float4 *)ref_input.data(), + (float4 *)ref_output.data(), + max_output_element, + max_input_element, + zero_io, + zero_element); + } + } + } + //case 2 : even channel + else if ((c_out % 2) == 0 && (c_in % 2) == 0){ + int32_t total_elements = n * h * w * c_out / 2; + int block_size = 256; + dim3 grid((total_elements + 255)/256); + dim3 block(block_size); + //for half_t + if (cutlass::sizeof_bits::value == 16){ + const __half2 zero = {0.0f, 0.0f}; + nhwc_padding_kernel<<>>(n, h, w, c_in/2, c_out/2, zero, (const __half2*)ref_input.data(), (__half2*)ref_output.data()); + } + //for float + else{ + const float2 zero = {0.0f, 0.0f}; + nhwc_padding_kernel<<>>(n, h, w, c_in/2, c_out/2, zero, (const float2*)ref_input.data(), (float2*)ref_output.data()); + } + } + //case 3 : odd channel + else{ + int32_t total_elements = n * h * w * c_out; + int block_size = 256; + dim3 grid((total_elements + 255)/256); + dim3 block(block_size); + const T zero = static_cast(0.0f); + nhwc_padding_kernel<<>>(n, h, w, c_in, c_out, zero, ref_input.data(), ref_output.data()); + } +} + + +} //namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..5633456c1412ff41366ec4c6ec5c3e6e3a2d6c19 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h @@ -0,0 +1,573 @@ +/****************************************************************************** + * Copyright (c) 2017 
- 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to do avg/max pooling on a device memory tensor with NHWC layout. 
+ */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" +#include "device_utils.h" +#include + +namespace cutlass { + +/** \brief interface to do avg/max pooling on a device memory tensor with NHWC layout. + * \tparam T: data type + */ +template +void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord filter_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + cutlass::MatrixCoord padding, + cutlass::MatrixCoord stride, + TensorRef ref_input, + TensorRef ref_output, + int poolingType, //0 for avg pooling ; 1 for max pooling + cudaStream_t stream); + +/** get the output size of pooling + */ +inline int getOutputSize(int H_W, int padding, int kernel_size, int stride) +{ + return (H_W + 2 * padding - kernel_size) / stride + 1; +} + +/** + * input is [N, H, W, C] + * assume stride == kernel_size + * output_h = (H + 2*padding_H - kernel_H)/stride_H + * output_w = (W + 2*padding_W - kernel_W)/stride_W + * output is [N, output_h, output_w, C] + * grid(N, output_h, output_w) + * block(min(C, 256)) : + * each block deals with C elements of output when each thread deals with ((C + 255)/256 element of output) +*/ +template +__global__ void pooling_nhwc_element1_kernel(T* output, + const T* input, + const int N, + const int H, + const int W, + const int C, + const int output_H, + const int output_W, + const int kernel_H, + const int kernel_W, + const int stride_H, + const int stride_W, + const int padding_H, + const int padding_W) +{ + const int tid = threadIdx.x; + const int n_idx = blockIdx.x; + const int output_h_idx = blockIdx.y; + const int output_w_idx = blockIdx.z; + + int h_start_idx = output_h_idx * stride_H - padding_H; + int h_end_idx = h_start_idx + kernel_H; + h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx; + h_end_idx = h_end_idx > H ? 
H : h_end_idx; + + int w_start_idx = output_w_idx * stride_W - padding_W; + int w_end_idx = w_start_idx + kernel_W; + w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx; + w_end_idx = w_end_idx > W ? W : w_end_idx; + + input += n_idx * H * W * C; + output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C; + const int kernel_size2 = kernel_H * kernel_W; + for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) { + float pooling; + if (IS_AVG_POOLING){ + pooling = 0.0f; + } + else{ + pooling = -FLT_MAX; + } + for (int h = h_start_idx; h < h_end_idx; h++) { + for (int w = w_start_idx; w < w_end_idx; w++) { + const int idx = (h * W + w) * C; + const float tmp = static_cast(input[idx + c_idx]); + if (IS_AVG_POOLING){ + pooling = pooling + tmp; + } + else{ + pooling = pooling > tmp ? pooling : tmp; + } + } + } + + T output_val; + if (IS_AVG_POOLING){ + output_val = T(pooling/kernel_size2); + } + else{ + output_val = T(pooling); + } + output[c_idx] = output_val; + } +} + +template +__global__ void pooling_nhwc_element2_kernel(T2* output, + const T2* input, + const int N, + const int H, + const int W, + const int C, + const int output_H, + const int output_W, + const int kernel_H, + const int kernel_W, + const int stride_H, + const int stride_W, + const int padding_H, + const int padding_W) +{ + const int tid = threadIdx.x; + const int n_idx = blockIdx.x; + const int output_h_idx = blockIdx.y; + const int output_w_idx = blockIdx.z; + + int h_start_idx = output_h_idx * stride_H - padding_H; + int h_end_idx = h_start_idx + kernel_H; + h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx; + h_end_idx = h_end_idx > H ? H : h_end_idx; + + int w_start_idx = output_w_idx * stride_W - padding_W; + int w_end_idx = w_start_idx + kernel_W; + w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx; + w_end_idx = w_end_idx > W ? 
W : w_end_idx; + + input += n_idx * H * W * C; + output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C; + const int kernel_size2 = kernel_H * kernel_W; + for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) { + float2 pooling; + if (IS_AVG_POOLING) { + pooling = {0.0f, 0.0f}; + } + else { + pooling = {-FLT_MAX, -FLT_MAX}; + } + for (int h = h_start_idx; h < h_end_idx; h++) { + for (int w = w_start_idx; w < w_end_idx; w++) { + const int idx = (h * W + w) * C; + const T2 tmp = input[idx + c_idx]; + const float2 tmp_flt2 = {static_cast(tmp.x), static_cast(tmp.y)}; + if (IS_AVG_POOLING) { + pooling.x += tmp_flt2.x; + pooling.y += tmp_flt2.y; + } + else { + pooling.x = pooling.x > tmp_flt2.x ? pooling.x : tmp_flt2.x; + pooling.y = pooling.y > tmp_flt2.y ? pooling.y : tmp_flt2.y; + } + } + } + + T2 output_val; + if (IS_AVG_POOLING) { + output_val.x = T(pooling.x/kernel_size2); + output_val.y = T(pooling.y/kernel_size2); + } + else { + output_val.x = T(pooling.x); + output_val.y = T(pooling.y); + } + output[c_idx] = output_val; + } +} + +/** + * output [N, 1, 1, C] + * input [N, H, W, C] + * grid(C, N) + * block(block_size) -- each block deals with H*W/block_size elements; +*/ +template +__global__ void pooling_nxhTo1x1_element1_kernel( + T* output, const T* input, const int N, const int HW, const int C) +{ + const int c_idx = blockIdx.x; + const int n_idx = blockIdx.y; + float pooling[1]; + if (IS_AVG_POOLING) { + pooling[0] = 0.0f; + } + else { + pooling[0] = -FLT_MAX; + } + const size_t input_offset = n_idx * HW * C + c_idx; + input += input_offset; + const size_t output_offset = n_idx * C + c_idx; + output += output_offset; + int tid = threadIdx.x; + + for (int index = tid; index < HW; index += blockDim.x) { + float val = static_cast(input[index * C]); + if (IS_AVG_POOLING) { + pooling[0] += val; + } + else { + pooling[0] = pooling[0] > val ? 
pooling[0] : val; + } + } + if (blockDim.x <= 32) { + if (IS_AVG_POOLING) { + warpReduceSum(pooling); + } + else { + warpReduceMax(pooling); + } + } + else { + if (IS_AVG_POOLING) { + blockReduceSum(pooling); + } + else { + blockReduceMax(pooling); + } + } + __syncthreads(); + if (threadIdx.x == 0) { + T output_val; + if (IS_AVG_POOLING) { + output_val = T(pooling[0] / HW); + } + else { + output_val = T(pooling[0]); + } + output[0] = output_val; + } +} + + +/** + * output [N, 1, 1, C] + * input [N, H, W, C] + * grid(C/2, N) + * block(block_size) -- each thread deals with H*W/block_size * 2 elements; +*/ +template +__global__ void pooling_nxhTo1x1_element2_kernel( + T2* output, const T2* input, const int N, const int HW, const int C) +{ + const int c_idx = blockIdx.x; + const int n_idx = blockIdx.y; + float pooling[2]; + if (IS_AVG_POOLING) { + pooling[0] = pooling[1] = 0.0f; + } + else { + pooling[0] = pooling[1] = -FLT_MAX; + } + const int C_2 = C / 2; + const size_t input_offset = n_idx * HW * C_2 + c_idx; + input += input_offset; + const size_t output_offset = n_idx * C_2 + c_idx; + output += output_offset; + int tid = threadIdx.x; + + for (int index = tid; index < HW; index += blockDim.x) { + T2 val = input[index * C_2]; + float2 val_flt2 = {static_cast(val.x), static_cast(val.y)}; + if (IS_AVG_POOLING) { + pooling[0] += val_flt2.x; + pooling[1] += val_flt2.y; + } + else { + pooling[0] = pooling[0] > val_flt2.x ? pooling[0] : val_flt2.x; + pooling[1] = pooling[1] > val_flt2.y ? 
pooling[1] : val_flt2.y; + } + } + if (blockDim.x <= 32) { + if (IS_AVG_POOLING) { + warpReduceSum(pooling); + } + else { + warpReduceMax(pooling); + } + } + else { + if (IS_AVG_POOLING) { + blockReduceSum(pooling); + } + else { + blockReduceMax(pooling); + } + } + __syncthreads(); + if (threadIdx.x == 0) { + T2 output_val; + if (IS_AVG_POOLING) { + output_val.x = T(pooling[0] / HW); + output_val.y = T(pooling[1] / HW); + } + else { + output_val.x = T(pooling[0]); + output_val.y = T(pooling[1]); + } + output[0] = output_val; + } +} + +template +void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord filter_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + cutlass::Tensor4DCoord padding, + cutlass::MatrixCoord stride, + TensorRef ref_input, + TensorRef ref_output, + int poolingType, //0 for avg pooling ; 1 for max pooling + cudaStream_t stream) { + + assert(input_tensor_size.n() == output_tensor_size.n() && + input_tensor_size.c() == output_tensor_size.c()); + + const int N = input_tensor_size.n(); + const int H = input_tensor_size.h(); + const int W = input_tensor_size.w(); + const int C = input_tensor_size.c(); + const int padding_H = padding.h(); + const int padding_W = padding.w(); + const int kernel_H = filter_tensor_size.h(); + const int kernel_W = filter_tensor_size.w(); + const int stride_H = stride.row(); + const int stride_W = stride.column(); + + const int output_H = getOutputSize(H, padding_H, kernel_H, stride_H); + const int output_W = getOutputSize(W, padding_W, kernel_W, stride_W); + + assert(output_tensor_size.h() == output_H && + output_tensor_size.w() == output_W); + + if (C % 2 != 0) { + if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) { + dim3 grid(C, N); + dim3 block(256); + if (H*W < block.x){ + block.x = (H*W + 31)/32*32; + } + if (poolingType == 0) { + pooling_nxhTo1x1_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H*W, + C); + } // if (poolingType == 0) + 
else { + pooling_nxhTo1x1_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H*W, + C); + } + } // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) + else { + dim3 grid(N, output_H, output_W); + dim3 block(256); + if (C < block.x) { + block.x = C; + } + if (poolingType == 0) { + pooling_nhwc_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H, + W, + C, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } // if (poolingType == 0) + else { + pooling_nhwc_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H, + W, + C, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } + } + } // if (C % 2 != 0)) + else { + if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) { + dim3 grid(C/2, N); + dim3 block(256); + if (H*W < block.x){ + block.x = (H*W + 31)/32*32; + } + if (poolingType == 0) { + if (std::is_same::value) { + pooling_nxhTo1x1_element2_kernel<<>>( + (float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H*W, + C); + } // if (std::is_same::value) + else { + pooling_nxhTo1x1_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H*W, + C); + } + } // if (poolingType == 0) + else { + if (std::is_same::value) { + pooling_nxhTo1x1_element2_kernel<<>>( + (float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H*W, + C); + } // if (std::is_same::value) + else { + pooling_nxhTo1x1_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H*W, + C); + } + } + } // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) + else { + dim3 grid(N, output_H, output_W); + dim3 block(256); + if (C/2 < block.x) { + block.x = C/2; + } + if (poolingType == 0) { + if (std::is_same::value) { + pooling_nhwc_element2_kernel<<>>( + 
(float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } // if (std::is_same::value) + else { + pooling_nhwc_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } + } // if (poolingType == 0) + else { + if (std::is_same::value) { + pooling_nhwc_element2_kernel<<>>( + (float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } // if (std::is_same::value) + else { + pooling_nhwc_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } + } + } + } +} + +} //namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h new file mode 100644 index 0000000000000000000000000000000000000000..babfecd39205ebff39794133868e4a95b7e9525c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h @@ -0,0 +1,144 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to transform a device memory tensor from NHWC layout to NCHW layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" + +namespace cutlass { + +/** \brief interface to transform a device memory tensor from NHWC layout to NCHW layout. 
+ * \tparam T: data type + */ +template +void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream); + + +template +__global__ void nhwc_to_nchw_kernel(T *output, + const T *input, + const int n, + const int h, + const int w, + const int c) { + + const int hw = h*w; + const int hwc = hw*c; + __shared__ T shbuf[32 * (32 + 1)]; + const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x; + const int32_t wid = tid / 32; + const int32_t lid = tid % 32; + const int32_t ni = blockIdx.z; + const int32_t hwi0 = blockIdx.y * 32; + const int32_t ci0 = blockIdx.x * 32; + + const size_t input_idx = ni * hwc + (hwi0 + wid) * c + ci0; + const T *A = input + input_idx; + if (ci0 + lid < c) { + const int lid_x_33 = lid * 33; + if ((hwi0 + 32) <= hw) { + int hwi = wid; // between 0 and 7 + CUTLASS_PRAGMA_UNROLL + for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) { + shbuf[lid_x_33 + hwi] = A[lid]; + A = &A[8 * c]; + hwi += 8; + } + } else { + for (int hwi = wid; hwi < 32; hwi += 8) { + if ((hwi + hwi0) < hw) { + shbuf[lid_x_33 + hwi] = A[lid]; + } + A = &A[8 * c]; + } + } + } + __syncthreads(); + + const int32_t hwiOut = hwi0 + lid; + output = &output[ni * hwc + hwiOut]; + if (hwiOut < hw) { + if (ci0 + 32 < c) { + int cI = wid; + CUTLASS_PRAGMA_UNROLL + for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) { + output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid]; + cI += 8; + } + } else { + for (int cI = wid; cI < 32; cI += 8) { + if (ci0 + cI < c) { + output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid]; + } + } + } + } +} + +template +void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream) { + + assert( + input_tensor_size.n() == output_tensor_size.n() && + input_tensor_size.h() == output_tensor_size.c() && + input_tensor_size.w() == output_tensor_size.h() && + 
input_tensor_size.c() == output_tensor_size.w()); + + int n = input_tensor_size.n(); + int h = input_tensor_size.h(); + int w = input_tensor_size.w(); + int c = input_tensor_size.c(); + + dim3 grid((c + 31)/32, (h*w + 31)/32, n); + dim3 block(32, 8); + nhwc_to_nchw_kernel<<>>(ref_output.data(), ref_input.data(), + n, h, w, c); + +} + +} //namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h new file mode 100644 index 0000000000000000000000000000000000000000..0d1b1af56e4463640edc3e9c82533baf815c9b27 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h @@ -0,0 +1,186 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/util/device_utils.h" +#include + +namespace cutlass { + +__global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input, + const float4 *weight, + const int m, const int n, float epsilon) { + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean; + float local_sums[1] = {0.0f}; + const int n_8 = n / 8; + int offset = m_idx * n_8; + input += offset; + output += offset; + + for (int index = tid; index < n_8; index += bdimx) { + const float4 local_val = input[index]; + const half2 *h1 = (half2 *)&local_val.x; + const half2 *h2 = (half2 *)&local_val.y; + const half2 *h3 = (half2 *)&local_val.z; + const half2 *h4 = (half2 *)&local_val.w; + local_sums[0] += static_cast(h1->x) * static_cast(h1->x) + + static_cast(h1->y) * static_cast(h1->y) + + static_cast(h2->x) * static_cast(h2->x) + + static_cast(h2->y) * 
static_cast(h2->y) + + static_cast(h3->x) * static_cast(h3->x) + + static_cast(h3->y) * static_cast(h3->y) + + static_cast(h4->x) * static_cast(h4->x) + + static_cast(h4->y) * static_cast(h4->y); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = rsqrtf(local_sums[0] / n + epsilon); + } + __syncthreads(); + + for (int index = tid; index < n_8; index += bdimx) { + const float4 local_val = input[index]; + const float4 weight_val = weight[index]; + + const half2 *l1 = (half2 *)&local_val.x; + const half2 *l2 = (half2 *)&local_val.y; + const half2 *l3 = (half2 *)&local_val.z; + const half2 *l4 = (half2 *)&local_val.w; + + const half2 *g1 = (half2 *)&weight_val.x; + const half2 *g2 = (half2 *)&weight_val.y; + const half2 *g3 = (half2 *)&weight_val.z; + const half2 *g4 = (half2 *)&weight_val.w; + + float4 tmp; + half2 *h1 = (half2 *)&tmp.x; + half2 *h2 = (half2 *)&tmp.y; + half2 *h3 = (half2 *)&tmp.z; + half2 *h4 = (half2 *)&tmp.w; + + h1->x = half(static_cast(l1->x) * s_mean * static_cast(g1->x)); + h1->y = half(static_cast(l1->y) * s_mean * static_cast(g1->y)); + h2->x = half(static_cast(l2->x) * s_mean * static_cast(g2->x)); + h2->y = half(static_cast(l2->y) * s_mean * static_cast(g2->y)); + h3->x = half(static_cast(l3->x) * s_mean * static_cast(g3->x)); + h3->y = half(static_cast(l3->y) * s_mean * static_cast(g3->y)); + h4->x = half(static_cast(l4->x) * s_mean * static_cast(g4->x)); + h4->y = half(static_cast(l4->y) * s_mean * static_cast(g4->y)); + + output[index] = tmp; + } +} + +template +__global__ void rmsnorm_twoPassAlgo_e1(T* output, + const T* input, + const T* weight, + const int m, const int n, + float epsilon) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean; + float local_sums[1] = {0.0f}; + int offset = m_idx * n; + input += offset; + output += offset; + + for (int index = tid ; index < n ; 
index += bdimx){ + float local_val = static_cast(input[index]); + local_sums[0] += local_val * local_val; + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = rsqrtf(local_sums[0] / n + epsilon); + } + __syncthreads(); + + for (int index = tid ; index < n ; index += bdimx){ + const T weight_val = weight[index]; + const T local_val = input[index]; + output[index] = T(static_cast(local_val) * s_mean * static_cast(weight_val)); + } +} + +template +void rmsnorm(cutlass::MatrixCoord tensor_size, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_weight, + cudaStream_t stream, float epsilon = 1e-5f){ + const int m = tensor_size.row(); + const int n = tensor_size.column(); + T* output = ref_output.data(); + const T* input = ref_input.data(); + const T* weight = ref_weight.data(); + dim3 grid(m); + + if (n % 8 == 0 && std::is_same::value) { + dim3 block(cutlass::platform::min(1024, (n / 8 + 31) / 32 * 32)); + + rmsnorm_twoPassAlgo_e8<<>>( + (float4 *)output, (const float4 *)input, (const float4 *)weight, m, n, epsilon); + } else { + dim3 block(cutlass::platform::min(1024, ((n + 31)/32 + 31)/32*32)); + + rmsnorm_twoPassAlgo_e1<<>>( + output, input, weight, m, n, epsilon); + } + + auto result = cudaGetLastError(); + if (result != cudaSuccess) { + std::cerr << "CUDA error: " << cudaGetErrorString(result) << std::endl; + abort(); + } +} + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..9747d50975d7d35df287f6b056aedc489adb317c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h @@ -0,0 +1,127 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief utils code for device cutlass code +*/ + +#pragma once + +#include +#include +#define FINAL_MASK 0xffffffff + +struct half4 { + half x, y, z, w; +}; + +template +__inline__ __device__ T warpReduceSum(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceSum(T* val) +{ + __shared__ T shared[NUM][33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduceSum(val); + + if (lane == 0) { +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? shared[i][lane] : (T)(0.0f); + } + warpReduceSum(val); + return (T)0.0f; +} + +template +__inline__ __device__ T warpReduceMax(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32)); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceMax(T* val) +{ + static __shared__ T shared[32][NUM]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + warpReduceMax(val); // get maxx in each warp + + if (lane == 0) // record in-warp maxx by warp Idx + { +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[wid][i] = val[i]; + } + } + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? 
shared[lane][i] : (T)(-FLT_MAX); + } + warpReduceMax(val); + + return (T)0.0f; +} + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h new file mode 100644 index 0000000000000000000000000000000000000000..6565aba9607ad68defacb6e98d9f9bbc944cd48d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h @@ -0,0 +1,157 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief This header contains a class to parametrize a statistical distribution function. +*/ + +#include + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Distribution type +struct Distribution { + /// Variant types + enum Kind { Invalid, Uniform, Gaussian, Identity, Sequential, AllZeros, AllOnes }; + + /// Distribution state + union { + /// Uniform distribution + struct { + double min; + double max; + // Percent elements set to NaN + double pnan; + } uniform; + + /// Gaussian distribution + struct { + double mean; + double stddev; + double pnz; + double pnzA; + double pnzB; + double pnzC; + } gaussian; + + /// Elements are linear combination of row and column index + struct { + double start; + double delta; + } sequential; + }; + + /// Active variant kind + Kind kind; + + /// Random values are cast to integer after scaling by this power of two + int int_scale; + + // + // Methods + // + + Distribution() : kind(Invalid), int_scale(0) {} + +/// Configures distribution as uniform random + Distribution &set_uniform(double _min, double _max, int _int_scale = 0, double _pnan = 0) { + kind = Uniform; + uniform.min = _min; + uniform.max = _max; + int_scale = _int_scale; + uniform.pnan = _pnan; + return *this; + } + + /// 
Configures distribution as Gaussian distribution + Distribution &set_gaussian(double _mean, double _stddev, int _int_scale = 0, double _pnz = 1.0) { + kind = Gaussian; + gaussian.mean = _mean; + gaussian.stddev = _stddev; + gaussian.pnz = _pnz; + gaussian.pnzA = _pnz; + gaussian.pnzB = _pnz; + gaussian.pnzC = _pnz; + int_scale = _int_scale; + return *this; + } + + /// Sets identity + Distribution &set_identity() { + kind = Identity; + return *this; + } + + /// Sets sequential + Distribution &set_sequential(double start, double delta, int _int_scale = 0) { + kind = Sequential; + sequential.start = start; + sequential.delta = delta; + int_scale = _int_scale; + return *this; + } +}; + +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Prints a Distribution to ostream +inline std::ostream &operator<<(std::ostream &out, cutlass::Distribution const &dist) { + switch (dist.kind) { + case cutlass::Distribution::Uniform: + out << "uniform, min: " << dist.uniform.min << ", max: " << dist.uniform.max + << ", pnan: " << dist.uniform.pnan; + break; + case cutlass::Distribution::Gaussian: + out << "gaussian, mean: " << dist.gaussian.mean << ", stddev: " << dist.gaussian.stddev + << ", pnzA: " << dist.gaussian.pnzA << ", pnzB: " + << dist.gaussian.pnzB << ", pnzC: " << dist.gaussian.pnzC; + break; + case cutlass::Distribution::Identity: + out << "identity"; + break; + case cutlass::Distribution::Sequential: + out << "sequential"; + break; + default: + out << "unknown"; + } + + out << ", int_scale: " << dist.int_scale; + + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h new file mode 100644 
index 0000000000000000000000000000000000000000..f2b7df6cb1c465a312d76566768cb79fcdfffee4 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h @@ -0,0 +1,69 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief C++ exception semantics for CUDA error codes + */ + +#include +#include +#include + +#include "cutlass/platform/platform.h" + +namespace cutlass { + +/// C++ exception wrapper for CUDA \p cudaError_t +class cuda_exception : public std::exception { + public: + /// Constructor + cuda_exception(const char* msg = "", cudaError_t err = cudaErrorUnknown) : msg(msg), err(err) {} + + /// Returns the underlying CUDA \p cudaError_t + cudaError_t cudaError() const { return err; } + + protected: + /// Explanatory string + const char* msg; + + /// Underlying CUDA \p cudaError_t + cudaError_t err; +}; + +/// Writes a cuda_exception instance to an output stream +inline std::ostream& operator<<(std::ostream& out, cuda_exception const& e) { + return out << e.what() << ": " << cudaGetErrorString(e.cudaError()); +} + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp new file mode 100644 index 0000000000000000000000000000000000000000..be2264466e350c062900a50e27e923847186d084 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp @@ -0,0 +1,369 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GETT command line parser to gather semantic modes, their stride order, and extents. 
+*/ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass/util/command_line.h" + +namespace cutlass { + +// Output shortcuts +std::ostream& operator<<(std::ostream& os, std::vector data) { + for (auto& a : data) os << a; + return os; +} + +template +std::ostream& operator<<(std::ostream& os, std::vector data) { + for (auto& a : data) os << a << " "; + return os; +} + +struct GettCommandLine { + struct GettProblem { + using extent_type = int; + using stride_type = int64_t; + + // Row modes: appear in A and C/D + std::vector M; + std::vector ldAm; + std::vector ldCm; + + // Column modes: appear in B and C/D + std::vector N; + std::vector ldBn; + std::vector ldCn; + + // Reduction modes: appear in A and B + std::vector K; + std::vector ldAk; + std::vector ldBk; + + // Batch modes: appear in all in/out tensors + std::vector L; + std::vector ldAl; + std::vector ldBl; + std::vector ldCl; + }; + + static GettProblem + parse(int argc, char const* argv[], bool parse_verbose = false) { + using extent_type = typename GettProblem::extent_type; + using stride_type = typename GettProblem::stride_type; + + cutlass::CommandLine cmd(argc, argv); + + // modeA + std::vector a_mode; + cmd.get_cmd_line_arguments("modeA", a_mode); + + // modeB + std::vector b_mode; + cmd.get_cmd_line_arguments("modeB", b_mode); + + // modeC + std::vector c_mode; + cmd.get_cmd_line_arguments("modeC", c_mode); + + + // mode_sizes + std::map mode_size; + // First, initialize all modes in a, b, c to make sure they're in map + for (char a : a_mode) mode_size[a] = 1; + for (char b : b_mode) mode_size[b] = 1; + for (char c : c_mode) mode_size[c] = 1; + + // Then, overwrite the ones in -extent + std::vector > extent_tokens; + cmd.get_cmd_line_argument_pairs("extents", extent_tokens); + for (auto e : extent_tokens) { + if (std::get<0>(e).size() > 1) { + std::cerr << "ERROR: Mode name must only be 1 character long.\n"; + print_usage(); + exit(1); + } 
+ char label = std::get<0>(e)[0]; + int size = std::stoi(std::get<1>(e)); + mode_size[label] = size; + } + + // Print out symbolic modes and their extents + if (parse_verbose) { + std::cout << "C_" << c_mode << " = A_" << a_mode << " * B_" << b_mode << "\n"; + for (auto e : mode_size) std::cout << " " << std::get<0>(e) << " : " << std::get<1>(e) << "\n"; + } + + // + // Collect/Compute strides + // + + std::map mode_ldA; + std::map mode_ldB; + std::map mode_ldC; + + { + stride_type current; + + current = 1; + for (char a : a_mode) { mode_ldA[a] = current; current *= mode_size[a]; } + + current = 1; + for (char b : b_mode) { mode_ldB[b] = current; current *= mode_size[b]; } + + current = 1; + for (char c : c_mode) { mode_ldC[c] = current; current *= mode_size[c]; } + } + + // + // Collect mode categories + // + + std::vector row_mode; // rows + std::vector col_mode; // columns + std::vector red_mode; // reductions + std::vector bat_mode; // batches + + { + std::vector a_label = a_mode; + std::vector b_label = b_mode; + std::vector c_label = c_mode; + + std::sort(std::begin(a_label), std::end(a_label)); + std::sort(std::begin(b_label), std::end(b_label)); + std::sort(std::begin(c_label), std::end(c_label)); + + // std::set_intersections to find semantic category of each symbolic mode + std::set_intersection(std::begin(a_label), std::end(a_label), + std::begin(c_label), std::end(c_label), + std::back_inserter(row_mode)); + + std::set_intersection(std::begin(b_label), std::end(b_label), + std::begin(c_label), std::end(c_label), + std::back_inserter(col_mode)); + + std::set_intersection(std::begin(a_label), std::end(a_label), + std::begin(b_label), std::end(b_label), + std::back_inserter(red_mode)); + + std::set_intersection(std::begin(row_mode), std::end(row_mode), + std::begin(col_mode), std::end(col_mode), + std::back_inserter(bat_mode)); + + // std::set_difference to remove batch modes from other semantic modes + for (char l : bat_mode) { + 
row_mode.erase(std::remove(std::begin(row_mode), std::end(row_mode), l), std::end(row_mode)); + col_mode.erase(std::remove(std::begin(col_mode), std::end(col_mode), l), std::end(col_mode)); + red_mode.erase(std::remove(std::begin(red_mode), std::end(red_mode), l), std::end(red_mode)); + } + } + + // Print out the semantic association of each symbolic mode + if (parse_verbose) { + std::cout << " rows : " << row_mode << '\n'; + std::cout << " cols : " << col_mode << '\n'; + std::cout << " reds : " << red_mode << '\n'; + std::cout << " bats : " << bat_mode << '\n'; + } + + // + // Permute modes + // + + // Permute the batched modes to promote coalescing + // Sort the batched modes by min(ldAl,ldBl) and in case of a tie by the size + std::sort(std::begin(bat_mode), std::end(bat_mode), [&](char l1, char l2) { + return std::tie(std::min(mode_ldA[l1],mode_ldB[l1]),mode_size[l1]) + < std::tie(std::min(mode_ldA[l2],mode_ldB[l2]),mode_size[l2]); + }); + // Compute sizes and strides of ordered reduction modes + std::vector L; + std::vector ldAl; + std::vector ldBl; + std::vector ldCl; + for (char l : bat_mode) { + L.push_back(mode_size[l]); + ldAl.push_back(mode_ldA[l]); + ldBl.push_back(mode_ldB[l]); + ldCl.push_back(mode_ldC[l]); + } + + // Permute the reduction modes to promote coalescing + // Sort the reduction modes by min(ldAk,ldBk) and in case of a tie by the size + std::sort(std::begin(red_mode), std::end(red_mode), [&](char k1, char k2) { + return std::tie(std::min(mode_ldA[k1],mode_ldB[k1]),mode_size[k1]) + < std::tie(std::min(mode_ldA[k2],mode_ldB[k2]),mode_size[k2]); + }); + // Compute sizes and strides of ordered reduction modes + std::vector K; + std::vector ldAk; + std::vector ldBk; + for (char k : red_mode) { + K.push_back(mode_size[k]); + ldAk.push_back(mode_ldA[k]); + ldBk.push_back(mode_ldB[k]); + } + + // Permute the row modes to promote coalescing + // Sort the row modes by min(ldAm,ldCm) and in case of a tie by ldAm + std::sort(std::begin(row_mode), 
std::end(row_mode), [&](char m1, char m2) { + return std::tie(std::min(mode_ldA[m1],mode_ldC[m1]),mode_ldA[m1]) + < std::tie(std::min(mode_ldA[m2],mode_ldC[m2]),mode_ldA[m2]); + }); + // Compute sizes and strides of ordered row modes + std::vector M; + std::vector ldAm; + std::vector ldCm; + for (char m : row_mode) { + M.push_back(mode_size[m]); + ldAm.push_back(mode_ldA[m]); + ldCm.push_back(mode_ldC[m]); + } + + // Permute the col modes to promote coalescing + // Sort the col modes by min(ldBn,ldCn) and in case of a tie by ldBn + std::sort(std::begin(col_mode), std::end(col_mode), [&](char n1, char n2) { + return std::tie(std::min(mode_ldB[n1],mode_ldC[n1]),mode_ldB[n1]) + < std::tie(std::min(mode_ldB[n2],mode_ldC[n2]),mode_ldB[n2]); + }); + // Compute sizes and strides of ordered col modes + std::vector N; + std::vector ldBn; + std::vector ldCn; + for (char n : col_mode) { + N.push_back(mode_size[n]); + ldBn.push_back(mode_ldB[n]); + ldCn.push_back(mode_ldC[n]); + } + + if (parse_verbose) { + std::cout << "C_"; + if (! row_mode.empty()) { + std::cout << "(" << row_mode << ")"; + } + if (! col_mode.empty()) { + std::cout << "(" << col_mode << ")"; + } + if (! bat_mode.empty()) { + std::cout << "(" << bat_mode << ")"; + } + std::cout << " = A_"; + if (! row_mode.empty()) { + std::cout << "(" << row_mode << ")"; + } + if (! red_mode.empty()) { + std::cout << "(" << red_mode << ")"; + } + if (! bat_mode.empty()) { + std::cout << "(" << bat_mode << ")"; + } + std::cout << " * B_"; + if (! col_mode.empty()) { + std::cout << "(" << col_mode << ")"; + } + if (! red_mode.empty()) { + std::cout << "(" << red_mode << ")"; + } + if (! 
bat_mode.empty()) { + std::cout << "(" << bat_mode << ")"; + } + std::cout << '\n'; + + int M_size = std::accumulate(std::begin(M), std::end(M), 1, std::multiplies<>{}); + int N_size = std::accumulate(std::begin(N), std::end(N), 1, std::multiplies<>{}); + int K_size = std::accumulate(std::begin(K), std::end(K), 1, std::multiplies<>{}); + int L_size = std::accumulate(std::begin(L), std::end(L), 1, std::multiplies<>{}); + + std::cout << " M : (" << M_size << ") "; + for (char m : row_mode) std::cout << m << ":" << mode_size[m] << " "; + std::cout << '\n'; + std::cout << " N : (" << N_size << ") "; + for (char n : col_mode) std::cout << n << ":" << mode_size[n] << " "; + std::cout << '\n'; + std::cout << " K : (" << K_size << ") "; + for (char k : red_mode) std::cout << k << ":" << mode_size[k] << " "; + std::cout << '\n'; + std::cout << " L : (" << L_size << ") "; + for (char l : bat_mode) std::cout << l << ":" << mode_size[l] << " "; + std::cout << '\n'; + + std::cout << " ldAm : " << ldAm << '\n'; + std::cout << " ldAk : " << ldAk << '\n'; + std::cout << " ldAl : " << ldAl << '\n'; + std::cout << " ldBn : " << ldBn << '\n'; + std::cout << " ldBk : " << ldBk << '\n'; + std::cout << " ldBl : " << ldBl << '\n'; + std::cout << " ldCm : " << ldCm << '\n'; + std::cout << " ldCn : " << ldCn << '\n'; + std::cout << " ldCl : " << ldCl << '\n'; + } + + return {M, ldAm, ldCm, + N, ldBn, ldCn, + K, ldAk, ldBk, + L, ldAl, ldBl, ldCl}; + } + + static void + print_usage() { + std::cout << + "GETT problem command line parser:\n" + " --modeA=\n" + " A comma delimited list of characters that correspond to the row, reduction, and batch modes in A tensor.\n" + " The semantic association of each symbolic mode is determined automatically.\n\n" + + " --modeB=\n" + " A comma delimited list of characters that correspond to the column, reduction, and batch modes in B tensor.\n" + " The semantic association of each symbolic mode is determined automatically.\n\n" + + " --modeC=\n" + " A comma 
delimited list of characters that correspond to the row, column, and batch modes in B tensor.\n" + " The semantic association of each symbolic mode is determined automatically.\n\n" + + " --extents=\n" + " A command delimited list of symbolic mode and its corresponding extent.\n" + " Extents are defaulted to 1 if any are not provided.\n\n" + + "Example usage: gett.exe --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extents=m:4096,n:4096,k:4096\n"; + } +}; + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp new file mode 100644 index 0000000000000000000000000000000000000000..58d08b860c9e665d170fd022ed0d95875e029019 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp @@ -0,0 +1,116 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +#include + +namespace cute +{ + +void +device_init(int device_id, bool quiet = false) +{ + cudaDeviceProp device_prop; + std::size_t device_free_physmem; + std::size_t device_total_physmem; + + CUTE_CHECK_ERROR(cudaSetDevice(device_id)); + CUTE_CHECK_ERROR(cudaMemGetInfo(&device_free_physmem, &device_total_physmem)); + CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); + + if (device_prop.major < 1) { + fprintf(stderr, "Device does not support CUDA.\n"); + exit(1); + } + + //float device_giga_bandwidth = float(device_prop.memoryBusWidth) * device_prop.memoryClockRate * 2 / 8 / 1000 / 1000; + + if (!quiet) { + printf("Using device %d: %s (SM%d, %d SMs)\n", + device_id, device_prop.name, + device_prop.major * 10 + device_prop.minor, + device_prop.multiProcessorCount); + fflush(stdout); + } +} + +/** + * Convert the SM version (e.g. v7.0, v7.5) to the physical number of cores. 
/**
 * Convert the SM version (e.g. v7.0, v7.5) to the number of cores per SM.
 *
 * \param major  SM major version
 * \param minor  SM minor version
 * \return cores per SM for the given version. If the version is not in the
 *         table, a message is printed to stdout and the value of the last
 *         (newest) table entry is returned so callers can keep running.
 */
inline int
_ConvertSMVer2Cores(int major, int minor)
{
  // One table row: SM is encoded as 0xMm (M = major version, m = minor version).
  struct SMtoCores {
    int SM;
    int Cores;
  };

  // Entries for SM 8.x and 9.0 were missing, so those devices previously fell
  // through to the "undefined" path and reported 64 cores/SM.
  static constexpr SMtoCores kCoresPerSM[] = {
      {0x30, 192},
      {0x32, 192},
      {0x35, 192},
      {0x37, 192},
      {0x50, 128},
      {0x52, 128},
      {0x53, 128},
      {0x60, 64},
      {0x61, 128},
      {0x62, 128},
      {0x70, 64},
      {0x72, 64},
      {0x75, 64},
      {0x80, 64},
      {0x86, 128},
      {0x87, 128},
      {0x89, 128},
      {0x90, 128}};

  constexpr int kNumEntries = static_cast<int>(sizeof(kCoresPerSM) / sizeof(kCoresPerSM[0]));

  // Encode once instead of on every loop iteration.
  int const sm = (major << 4) + minor;

  for (int i = 0; i < kNumEntries; ++i) {
    if (kCoresPerSM[i].SM == sm) {
      return kCoresPerSM[i].Cores;
    }
  }

  // If we don't find the value, default to the last table entry so the caller
  // can still run.
  int const fallback_cores = kCoresPerSM[kNumEntries - 1].Cores;
  printf("MapSMtoCores for SM %d.%d is undefined."
         " Default to use %d Cores/SM\n",
         major, minor, fallback_cores);

  return fallback_cores;
}
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief reorder data from the host side +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/tensor_view.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { + +/// This is needed for the interleaved integer tensor core kernels. The purpose +/// is to use skip the shared memory part in the epilogue. 
// Copies every element src(k, n) to dest(k, n'), where n' is a permutation of
// the column index n inside its group of `Interleaved` consecutive columns.
// Loops run over problem_size.n() columns and problem_size.k() rows.
//
// NOTE(review): `Interleaved` is not declared in the text visible here; it is
// presumably a template parameter whose declaration was lost from the
// `template` header in this view -- confirm against the original header.
template
void reorder_column(TensorRef dest,
                    TensorRef src,
                    cutlass::gemm::GemmCoord problem_size) {
  // Column width used as the stride of the first permutation term below.
  const int InstructionShapeCol = 8;
  // 4 threads per Quad: 8 / 4 = 2 elements per thread
  const int ElementsPerThread = InstructionShapeCol / 4;
  // 4 threads per Quad: Interleaved / 4 elements per thread after reordering
  const int ReorderedElementsPerThread =
      Interleaved / 4;

  for (int n = 0; n < problem_size.n(); n++) {
    for (int k = 0; k < problem_size.k(); k++) {
      // Destination column =
      //   base of the Interleaved-wide group containing n
      // + (index of n's ElementsPerThread-chunk within its
      //    ReorderedElementsPerThread-chunk) * InstructionShapeCol
      // + (index of n's ReorderedElementsPerThread-chunk within the group)
      //    * ElementsPerThread
      // + position of n inside its ElementsPerThread-chunk.
      dest.at({k, (n / Interleaved) * Interleaved +
                      ((n % ReorderedElementsPerThread) / ElementsPerThread) *
                          InstructionShapeCol +
                      ((n % Interleaved) / ReorderedElementsPerThread) *
                          ElementsPerThread +
                      (n % ElementsPerThread)}) = src.at({k, n});
    }
  }
}

// Wraps the data pointers of `dest` and `src` in new tensor references that use
// stride(0) of the respective input, then forwards to reorder_column with the
// same problem_size.
//
// NOTE(review): the layout type of the mapped references and the template
// argument passed to reorder_column are not visible in this view.
template
void reorder_convK(TensorRef dest,
                    TensorRef src,
                    cutlass::gemm::GemmCoord problem_size) {

    // Re-view the same memory; no data is copied by these two lines.
    TensorRef> mappedDest(dest.data(), dest.stride(0));
    TensorRef> mappedSrc(src.data(), src.stride(0));

    reorder_column(
      mappedDest, mappedSrc, problem_size);
}

/// This is needed for the sparse tensor core kernels.  The purpose
/// is to use ldmatrix to load from shared memory to the register file.
///
/// Copies every element src(m, k) to dest(dest_row, dest_col) for
/// m in [0, problem_size.m()) and k in [0, problem_size.k()), where the
/// destination coordinate is computed in two steps (see comments below).
template
void reorder_meta(TensorRef dest,
                  TensorRef src,
                  cutlass::gemm::GemmCoord problem_size) {
  for (int m = 0; m < problem_size.m(); m++) {
    for (int k = 0; k < problem_size.k(); k++) {
      // First reorder the rows.
      // Rows are processed in groups of 32 when Element is 2 bytes wide,
      // otherwise in groups of 16; the interleave factor is 4 or 2 respectively.
      int group = (sizeof(Element) == 2) ? 32 : 16;
      int interweave = (sizeof(Element) == 2) ? 4 : 2;

      // Within a group: (m % 8) selects a slot spaced `interweave` rows apart,
      // and (m % group) / 8 selects the position inside that slot.
      int dest_row = m / group * group + (m % 8) * interweave + (m % group) / 8;
      int dest_col = k;

      // Next swizzle the 2x2 blocks from Z to N.
      // Only the two off-diagonal entries of each 2x2 block trade places:
      // (even row, odd col) <-> (odd row, even col).
      if (((dest_row % 2) == 0) && ((dest_col % 2) == 1)) {
        ++dest_row;
        --dest_col;
      } else if (((dest_row % 2) == 1) && ((dest_col % 2) == 0)) {
        --dest_row;
        ++dest_col;
      }

      dest.at({dest_row, dest_col}) = src.at({m, k});
    }
  }
}
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief HostTensor contributes management for both host and device memory. + + HostTensor allocates host and device memory upon construction. Basic element-wise operations on + host memory synchronize device memory automatically. Explicit copy operations provide abstractions + for CUDA memcpy operations. + + Call {host, device}_{data, ref, view}() for accessing host or device memory. + + See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details. 
+*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/fast_math.h" + +#include "device_memory.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Host tensor +template < + /// Data type of element stored within tensor (concept: NumericType) + typename Element_, + /// Defines a mapping from logical coordinate to linear memory (concept: Layout) + typename Layout_ +> +class HostTensor { +public: + + /// Data type of individual access + using Element = Element_; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Layout's stride vector + using Stride = typename Layout::Stride; + + /// Tensor reference to device memory + using TensorRef = TensorRef; + + /// Tensor reference to constant device memory + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + /// Tensor reference to device memory + using TensorView = TensorView; + + /// Tensor reference to constant device memory + using ConstTensorView = typename TensorView::ConstTensorView; + + /// Reference to element in tensor + using Reference = typename TensorRef::Reference; + + /// Constant reference to element in tensor + using ConstReference = typename ConstTensorRef::Reference; + +private: + using StorageUnit = typename platform::conditional_t, uint8_t, // Avoid the std::vector specialization + typename platform::conditional_t::value % 8 == 0, // Handle subbyte types + Element, uint8_t>>; + using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator; + 
static constexpr int kContainerTypeNumBits = StorageContainerCalculator::kContainerTypeNumBits; + static constexpr int kContainerTypeNumLogicalElements = StorageContainerCalculator::kContainerTypeNumLogicalElements; + static constexpr int kContainerTypeNumBytes = StorageContainerCalculator::kContainerTypeNumBytes; + static constexpr int kContainerTypeNumStorageUnit = StorageContainerCalculator::kContainerTypeNumStorageUnit; + + // + // Data members + // + + /// Extent of tensor in logical dimensions + TensorCoord extent_; + + /// Layout object + Layout layout_; + + /// Host-side memory allocation + std::vector host_; + + /// Device-side memory + device_memory::allocation device_; + + /// number of containers + size_t count_to_container_storage_unit_count(size_t count) { + return (count + kContainerTypeNumLogicalElements - 1) / kContainerTypeNumLogicalElements * kContainerTypeNumStorageUnit; + } + +public: + // + // Device and Host Methods + // + + /// Default constructor + HostTensor() {} + + /// Constructs a tensor given an extent. 
Assumes a packed layout + HostTensor( + TensorCoord const &extent, + bool device_backed = true + ) { + + this->reset(extent, Layout::packed(extent), device_backed); + } + + /// Constructs a tensor given an extent and layout + HostTensor( + TensorCoord const &extent, + Layout const &layout, + bool device_backed = true + ) { + + this->reset(extent, layout, device_backed); + } + + ~HostTensor() { } + + /// Clears the HostTensor allocation to size/capacity = 0 + void reset() { + extent_ = TensorCoord(); + layout_ = Layout::packed(extent_); + + host_.clear(); + device_.reset(); + } + + /// Resizes internal memory allocations without affecting layout or extent + void reserve( + size_t count, ///< size of tensor in elements + bool device_backed_ = true) { ///< if true, device memory is also allocated +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve(count=" << count << ", device_backed_=" << (device_backed_ ? "true" : "false") << ")"); +#endif + + device_.reset(); + host_.clear(); + + size_t count_container = count_to_container_storage_unit_count(count); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: host_.resize(" << count_container << ")"); +#endif + host_.resize(count_container); + + // Allocate memory + StorageUnit* device_memory = nullptr; + if (device_backed_) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: device_memory::allocate(" << count_container << ")"); +#endif + device_memory = device_memory::allocate(count_container); + } + device_.reset(device_memory, device_backed_ ? count_container : 0); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. 
+ + extent_ = extent; + layout_ = layout; + + reserve(size_t(layout_.capacity(extent_)), device_backed_); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. Assumes a packed tensor configuration. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + reset(extent, Layout::packed(extent), device_backed_); + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). + void resize( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + LongIndex new_size = size_t(layout_.capacity(extent_)); + LongIndex new_size_container = count_to_container_storage_unit_count((layout_.capacity(extent_))); + + if (static_cast(new_size_container) > host_.size()) { + reserve(new_size, device_backed_); + } + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration. + void resize( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + resize(extent, Layout::packed(extent), device_backed_); + } + + /// Returns the logical number of elements stored in the host tensor + size_t size() const { + return layout_.capacity(extent_); + } + + /// Returns the logical capacity in terms of number of elements. May be larger than the size(). 
+ LongIndex capacity() const { + return host_.size() / kContainerTypeNumStorageUnit * kContainerTypeNumLogicalElements; + } + + /// Gets pointer to host data + Element * host_data() { return reinterpret_cast(host_.data()); } + + /// Gets pointer to host data with a pointer offset + Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(host_data(), ptr_element_offset); } + + /// Gets a reference to an element in host memory + Reference host_data(LongIndex idx) { + return ReferenceFactory::get(host_data(), idx); + } + + /// Gets pointer to host data + Element const * host_data() const { return reinterpret_cast(host_.data()); } + + /// Gets pointer to host data with a pointer offset + Element const * host_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory::get(host_data(), ptr_element_offset); } + + /// Gets a constant reference to an element in host memory + ConstReference host_data(LongIndex idx) const { + return ReferenceFactory::get(host_data(), idx); + } + + /// Gets pointer to device data + Element * device_data() { return reinterpret_cast(device_.get()); } + + /// Gets pointer to device data + Element const * device_data() const { return reinterpret_cast(device_.get()); } + + /// Gets pointer to device data with a pointer offset + Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(device_data(), ptr_element_offset); } + + /// Gets pointer to device data with a pointer offset + Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory::get(device_data(), ptr_element_offset); } + + /// Accesses the tensor reference pointing to data + TensorRef host_ref(LongIndex ptr_element_offset=0) { return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } + + /// Accesses the tensor reference pointing to data + ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { return 
ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } + + /// Accesses the tensor reference pointing to data + TensorRef device_ref(LongIndex ptr_element_offset=0) { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_); + } + + /// Accesses the tensor reference pointing to data + TensorView host_view(LongIndex ptr_element_offset=0) { + return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView host_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + TensorView device_view(LongIndex ptr_element_offset=0) { + return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView device_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Returns true if device memory is allocated + bool device_backed() const { + return (device_.get() == nullptr) ? 
  /// Copy data from a caller-supplied device pointer into this tensor's host memory.
  void copy_in_device_to_host(
    Element const* ptr_device,    ///< source device memory
    LongIndex count = -1) {       ///< number of elements to transfer; if negative, capacity() elements are copied.

    // A non-negative count is clamped to capacity().
    if (count < 0) {
      count = capacity();
    }
    else {
      count = __NV_STD_MIN(capacity(), count);
    }
    // The transfer size is expressed in storage units, not logical elements.
    size_t container_count = count_to_container_storage_unit_count(count);
    device_memory::copy_to_host(
      host_.data(), reinterpret_cast(ptr_device), container_count);
  }

  /// Copy data from a caller-supplied device pointer into this tensor's device memory.
  void copy_in_device_to_device(
    Element const* ptr_device,    ///< source device memory
    LongIndex count = -1) {       ///< number of elements to transfer; if negative, capacity() elements are copied.

    // A non-negative count is clamped to capacity().
    if (count < 0) {
      count = capacity();
    }
    else {
      count = __NV_STD_MIN(capacity(), count);
    }
    size_t container_count = count_to_container_storage_unit_count(count);
    device_memory::copy_device_to_device(
      device_.get(), reinterpret_cast(ptr_device), container_count);
  }

  /// Copy data from a caller-supplied host pointer into this tensor's device memory.
  void copy_in_host_to_device(
    Element const* ptr_host,      ///< source host memory
    LongIndex count = -1) {       ///< number of elements to transfer; if negative, capacity() elements are copied.

    // A non-negative count is clamped to capacity().
    if (count < 0) {
      count = capacity();
    }
    else {
      count = __NV_STD_MIN(capacity(), count);
    }
    size_t container_count = count_to_container_storage_unit_count(count);
    device_memory::copy_to_device(
      device_.get(), reinterpret_cast(ptr_host), container_count);
  }

  /// Copy data from a caller-supplied host pointer into this tensor's host memory.
  void copy_in_host_to_host(
    Element const* ptr_host,      ///< source host memory
    LongIndex count = -1) {       ///< number of elements to transfer; if negative, capacity() elements are copied.

    // A non-negative count is clamped to capacity().
    if (count < 0) {
      count = capacity();
    }
    else {
      count = __NV_STD_MIN(capacity(), count);
    }
    size_t container_count = count_to_container_storage_unit_count(count);
    device_memory::copy_host_to_host(
      host_.data(), reinterpret_cast(ptr_host), container_count);
  }

  /// Copy data from this tensor's device memory out to a caller-supplied host pointer.
  void copy_out_device_to_host(
    Element * ptr_host,           ///< destination host memory
    LongIndex count = -1) const { ///< number of elements to transfer; if negative, capacity() elements are copied.

    // A non-negative count is clamped to capacity().
    if (count < 0) {
      count = capacity();
    }
    else {
      count = __NV_STD_MIN(capacity(), count);
    }
    size_t container_count = count_to_container_storage_unit_count(count);
    device_memory::copy_to_host(
      reinterpret_cast(ptr_host), device_.get(), container_count);
  }

  /// Copy data from this tensor's device memory out to a caller-supplied device pointer.
  void copy_out_device_to_device(
    Element * ptr_device,         ///< destination device memory
    LongIndex count = -1) const { ///< number of elements to transfer; if negative, capacity() elements are copied.

    // A non-negative count is clamped to capacity().
    if (count < 0) {
      count = capacity();
    }
    else {
      count = __NV_STD_MIN(capacity(), count);
    }
    size_t container_count = count_to_container_storage_unit_count(count);
    device_memory::copy_device_to_device(
      reinterpret_cast(ptr_device), device_.get(), container_count);
  }

  /// Copy data from this tensor's host memory out to a caller-supplied device pointer.
  void copy_out_host_to_device(
    Element * ptr_device,         ///< destination device memory
    LongIndex count = -1) const { ///< number of elements to transfer; if negative, capacity() elements are copied.

    // A non-negative count is clamped to capacity().
    if (count < 0) {
      count = capacity();
    }
    else {
      count = __NV_STD_MIN(capacity(), count);
    }
    size_t container_count = count_to_container_storage_unit_count(count);
    device_memory::copy_to_device(
      reinterpret_cast(ptr_device), host_.data(), container_count);
  }

  /// Copy data from this tensor's host memory out to a caller-supplied host pointer.
  void copy_out_host_to_host(
    Element * ptr_host,           ///< destination host memory
    LongIndex count = -1) const { ///< number of elements to transfer; if negative, capacity() elements are copied.

    // A non-negative count is clamped to capacity().
    if (count < 0) {
      count = capacity();
    }
    else {
      count = __NV_STD_MIN(capacity(), count);
    }
    size_t container_count = count_to_container_storage_unit_count(count);
    device_memory::copy_host_to_host(
      reinterpret_cast(ptr_host), host_.data(), container_count);
  }
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief HostTensor contributes management for both host and device memory. + + HostTensor allocates host and device memory upon construction. Basic element-wise operations on + host memory synchronize device memory automatically. Explicit copy operations provide abstractions + for CUDA memcpy operations. + + Call {host, device}_{data, ref, view}() for accessing host or device memory. + + See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details. 
+*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/tensor_ref_planar_complex.h" +#include "cutlass/tensor_view_planar_complex.h" + +#include "device_memory.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Host tensor +template < + /// Data type of element stored within tensor (concept: NumericType) + typename Element_, + /// Defines a mapping from logical coordinate to linear memory (concept: Layout) + typename Layout_ +> +class HostTensorPlanarComplex { +public: + + /// Data type of individual access + using Element = Element_; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Layout's stride vector + using Stride = typename Layout::Stride; + + /// Tensor reference to device memory + using TensorRef = TensorRefPlanarComplex; + + /// Tensor reference to constant device memory + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + /// Tensor reference to device memory + using TensorView = TensorViewPlanarComplex; + + /// Tensor reference to constant device memory + using ConstTensorView = typename TensorView::ConstTensorView; + + /// Reference to element in tensor + using Reference = typename TensorRef::Reference; + + /// Constant reference to element in tensor + using ConstReference = typename ConstTensorRef::Reference; + + private: + + // + // Data members + // + + /// Extent of tensor in logical dimensions + TensorCoord extent_; + + /// Layout object + Layout layout_; + + /// Host-side memory allocation + std::vector host_; + + /// Device-side memory + 
device_memory::allocation device_; + + public: + // + // Device and Host Methods + // + + /// Default constructor + HostTensorPlanarComplex() {} + + /// Constructs a tensor given an extent. Assumes a packed layout + HostTensorPlanarComplex( + TensorCoord const &extent, + bool device_backed = true + ) { + + this->reset(extent, Layout::packed(extent), device_backed); + } + + /// Constructs a tensor given an extent and layout + HostTensorPlanarComplex( + TensorCoord const &extent, + Layout const &layout, + bool device_backed = true + ) { + + this->reset(extent, layout, device_backed); + } + + ~HostTensorPlanarComplex() { } + + /// Clears the HostTensor allocation to size/capacity = 0 + void reset() { + extent_ = TensorCoord(); + layout_ = Layout::packed(extent_); + + host_.clear(); + device_.reset(); + } + + /// Resizes internal memory allocations without affecting layout or extent + void reserve( + size_t count, ///< size of tensor in elements + bool device_backed_ = true) { ///< if true, device memory is also allocated + + device_.reset(); + host_.clear(); + + host_.resize(count * 2); + + // Allocate memory + Element* device_memory = nullptr; + if (device_backed_) { + device_memory = device_memory::allocate(count * 2); + } + device_.reset(device_memory, device_backed_ ? count * 2 : 0); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + reserve(size_t(layout_.capacity(extent_)), device_backed_); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. Assumes a packed tensor configuration. 
+ void reset( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + reset(extent, Layout::packed(extent), device_backed_); + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). + void resize( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + LongIndex new_size = size_t(layout_.capacity(extent_)); + + if (static_cast(new_size * 2) > host_.size()) { + reserve(new_size); + } + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration. + void resize( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + resize(extent, Layout::packed(extent), device_backed_); + } + + /// Returns the number of elements stored in the host tensor + size_t size() const { + return host_.size() / 2; + } + + /// Returns the logical capacity based on extent and layout. May differ from size(). 
+ LongIndex capacity() const { + return layout_.capacity(extent_); + } + + /// Stride between real and imaginary parts + LongIndex imaginary_stride() const { + return host_.size() / 2; + } + + /// Gets pointer to host data + Element * host_data() { return host_.data(); } + + /// Gets pointer to host data imaginary part + Element * host_data_imag() { return host_.data() + imaginary_stride(); } + + /// Gets pointer to host data with a pointer offset + Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return host_data() + ptr_element_offset; } + + /// Gets pointer to host data with a pointer offset + Element * host_data_imag_ptr_offset(LongIndex ptr_element_offset) { return host_data_imag() + ptr_element_offset; } + + /// Gets a reference to an element in host memory + Reference host_data(LongIndex idx) { + return PlanarComplexReference(host_data() + idx, host_data_imag() + idx); + } + + /// Gets pointer to host data + Element const * host_data() const { return host_.data(); } + + /// Gets pointer to host data imaginary part + Element const * host_data_imag() const { return host_.data() + imaginary_stride(); } + + /// Gets a constant reference to an element in host memory + ConstReference host_data(LongIndex idx) const { + return PlanarComplexReference(host_data() + idx, host_data_imag() + idx); + } + + /// Gets pointer to device data + Element * device_data() { return device_.get(); } + + /// Gets pointer to device data with a pointer offset + Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return device_.get() + ptr_element_offset; } + + /// Gets pointer to device data + Element const * device_data() const { return device_.get(); } + + /// Gets pointer to device data with a pointer offset + Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return device_.get() + ptr_element_offset; } + + /// Gets a pointer to the device data imaginary part + Element * device_data_imag() { return device_.get() + 
imaginary_stride(); } + + /// Accesses the tensor reference pointing to data + TensorRef host_ref(LongIndex ptr_element_offset=0) { + return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef host_ref_real() { + return cutlass::TensorRef(host_data(), layout_); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef host_ref_imag() { + return cutlass::TensorRef(host_data_ptr_offset(imaginary_stride()), layout_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { + return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Accesses the tensor reference pointing to data + TensorRef device_ref(LongIndex ptr_element_offset=0) { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Accesses the tensor reference pointing to data + ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef device_ref_real() { + return cutlass::TensorRef(device_data(), layout_); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef device_ref_imag() { + return cutlass::TensorRef(device_data_ptr_offset(imaginary_stride()), layout_); + } + + /// Accesses the tensor reference pointing to data + TensorView host_view(LongIndex ptr_element_offset=0) { + return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView host_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, 
imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView host_view_real() { + return cutlass::TensorView(host_data(), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView host_view_imag() { + return cutlass::TensorView(host_data_ptr_offset(imaginary_stride()), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + TensorView device_view(LongIndex ptr_element_offset=0) { + return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView device_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView device_view_real() { + return cutlass::TensorView(device_data(), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView device_view_imag() { + return cutlass::TensorView(device_data_ptr_offset(imaginary_stride()), layout_, extent_); + } + + /// Returns true if device memory is allocated + bool device_backed() const { + return (device_.get() == nullptr) ? 
false : true; + } + + /// Returns the layout object + Layout layout() const { + return layout_; + } + + /// Returns the layout object's stride vector + Stride stride() const { + return layout_.stride(); + } + + /// Returns the layout object's stride in a given physical dimension + Index stride(int dim) const { + return layout_.stride().at(dim); + } + + /// Computes the offset of an index from the origin of the tensor + LongIndex offset(TensorCoord const& coord) const { + return layout_(coord); + } + + /// Returns a reference to the element at the logical Coord in host memory + Reference at(TensorCoord const& coord) { + return host_data(offset(coord)); + } + + /// Returns a const reference to the element at the logical Coord in host memory + ConstReference at(TensorCoord const& coord) const { + return host_data(offset(coord)); + } + + /// Returns the extent of the tensor + TensorCoord extent() const { + return extent_; + } + + /// Returns the extent of the tensor + TensorCoord & extent() { + return extent_; + } + + /// Copies data from device to host + void sync_host() { + if (device_backed()) { + device_memory::copy_to_host( + host_data(), device_data(), imaginary_stride() * 2); + } + } + + /// Copies data from host to device + void sync_device() { + if (device_backed()) { + device_memory::copy_to_device( + device_data(), host_data(), imaginary_stride() * 2); + } + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_host( + Element const* ptr_device_real, ///< source device memory + Element const* ptr_device_imag, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. 
+ + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_host( + host_data(), ptr_device_real, count); + + device_memory::copy_to_host( + host_data_imag(), ptr_device_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_device( + Element const* ptr_device_real, ///< source device memory + Element const* ptr_device_imag, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_device_to_device( + device_data(), ptr_device_real, count); + + device_memory::copy_device_to_device( + device_data_imag(), ptr_device_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_device( + Element const* ptr_host_real, ///< source host memory + Element const* ptr_host_imag, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_device( + device_data(), ptr_host_real, count); + + device_memory::copy_to_device( + device_data_imag(), ptr_host_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_host( + Element const* ptr_host_real, ///< source host memory + Element const* ptr_host_imag, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. 
+ + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_host_to_host( + host_data(), ptr_host_real, count); + + device_memory::copy_host_to_host( + host_data_imag(), ptr_host_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_host( + Element * ptr_host_real, ///< source device memory + Element * ptr_host_imag, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_host( + ptr_host_real, device_data(), count); + + device_memory::copy_to_host( + ptr_host_imag, device_data_imag(), count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_device( + Element * ptr_device_real, ///< source device memory + Element * ptr_device_imag, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_device_to_device( + ptr_device_real, device_data(), count); + + device_memory::copy_device_to_device( + ptr_device_imag, device_data_imag(), count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_device( + Element * ptr_device_real, ///< source device memory + Element * ptr_device_imag, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. 
+ + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_device( + ptr_device_real, host_data(), count); + + device_memory::copy_to_device( + ptr_device_imag, host_data_imag(), count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_host( + Element * ptr_host_real, ///< source host memory + Element * ptr_host_imag, ///< source host memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_host_to_host( + ptr_host_real, host_data(), count); + + device_memory::copy_host_to_host( + ptr_host_imag, host_data_imag(), count); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h new file mode 100644 index 0000000000000000000000000000000000000000..9cd62927432c65ce1f0187f46306f7e1198a1182 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h @@ -0,0 +1,157 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief uncompress sparse matrix from the host side +*/ +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/tensor_view.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { + +// uncompress sparse tensor core A matrix +template +void uncompress(TensorRef uncompressed_tensor_a, + TensorRef tensor_a, + TensorRef tensor_e, int row, int col) { + // How many uncompressed data we can get with ElementE meta data + int DecompressedElementsPerElementE = + 256 / cutlass::sizeof_bits::value; + + // Process 4bit meta data a time + int step; + + // 1:2 or 2:4 or 4:8 + int a, b; + + if (cutlass::sizeof_bits::value == 4) { + step = 8; + a = 4; + b = 8; + } else if (cutlass::sizeof_bits::value == 8) { + step = 4; + a = 2; + b = 4; + } else if (cutlass::sizeof_bits::value == 16) { + step = 4; + a = 2; + b = 4; + } else if (cutlass::sizeof_bits::value == 32) { + step = 2; + a = 1; + b = 2; + } + + int ElementsPerE = (cutlass::sizeof_bits::value == 4) ? 
2 : 1; + + for (int r = 0; r < row; ++r) { + for (int c = 0; c < (col / DecompressedElementsPerElementE); ++c) { + + ElementE meta = tensor_e.at(MatrixCoord(r, c)); + + for (int i = 0; i < DecompressedElementsPerElementE; i += step) { + int e = (meta >> (i / step * 4)) & 0xf; + int idx0 = e & 0x3; + int idx1 = e >> 2; + + if (a == 1) idx0 = idx0 / 2; + + for (int ii = 0; ii < step; ii += ElementsPerE) { + int real_col = + c * DecompressedElementsPerElementE + i + ii; + int compressed_col = (real_col / b) * a; + + if (ii == (idx0 * ElementsPerE)) { + uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = + tensor_a.at(MatrixCoord(r, compressed_col)); + if (ElementsPerE == 2) + uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = + tensor_a.at(MatrixCoord(r, compressed_col + 1)); + } else if ((ii == (idx1 * ElementsPerE)) && (a != 1)) { + uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = + tensor_a.at(MatrixCoord(r, compressed_col + ElementsPerE)); + if (ElementsPerE == 2) + uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = + tensor_a.at( + MatrixCoord(r, compressed_col + ElementsPerE + 1)); + } else { + uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = + ElementA(0); + if (ElementsPerE == 2) + uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = + ElementA(0); + } + } + } + } + } +} + +// uncompress ELL block sparse matrix +template +void uncompress_ell_block_sparse( + TensorRef uncompressed_tensor_a, + TensorRef tensor_a, + TensorRef ell_idx, + int rows, int cols, + int ell_num_cols, int ell_blocksize) { + + for (int r = 0; r < rows / ell_blocksize; ++r) { + for (int c = 0; c < ell_num_cols / ell_blocksize; ++c) { + + ElementE idx = ell_idx.at(MatrixCoord(r, c)); + + if (idx != -1) { + int row_begin = r * ell_blocksize; + int col_begin_real = idx * ell_blocksize; + int col_begin = c * ell_blocksize; + + for (int i = 0; i < ell_blocksize; ++i) { + for (int j = 0; j < ell_blocksize; ++j) { + uncompressed_tensor_a.at(MatrixCoord(row_begin + i, 
col_begin_real + j)) = + tensor_a.at( + MatrixCoord(row_begin + i, col_begin +j)); + } + } + } + } + } +} + +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h new file mode 100644 index 0000000000000000000000000000000000000000..6b72b043fc0c1271cf9f12e5cb9a81d29659cb0a --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h @@ -0,0 +1,38 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +// integer_sequence moved to cutlass/numeric_types.h + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..43f5a3f92d29f229703cc4c5f9071c11d0f89df4 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp @@ -0,0 +1,472 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Utilities for mixed input data type kernels. 
+*/ + +#pragma once + +#include +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/arch/mma_sm90.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cute/util/type_traits.hpp" + +namespace cutlass { + +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +template < + class QuantizedElement, + class DequantizedElement, + class OperandLayout, + class ElementScale, + class ElementZero, + class ScaleBroadCastLayout, + class ThrLayout> +__global__ void dequantize_kernel(DequantizedElement* dq_buffer, + QuantizedElement const* q_buffer, + OperandLayout const operand_layout, + ElementScale const* scale_buffer, + ElementZero const* zero_buffer, + ScaleBroadCastLayout const broadcasted_scale_layout, + ThrLayout thr_layout) { + using namespace cute; + + // Represent the full tensors to gmem elements. 
+ // These are expected to have shape [MN, K, L] + cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout); + cute::Tensor gmem_op_q = cute::make_tensor(cute::make_gmem_ptr(q_buffer), operand_layout); + // While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting + // It is expected that K % G == 0 + cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout); + cute::Tensor gmem_zero_broadcasted = cute::make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout); + + // Assign 1 thread per element in the thread block + auto blk_shape = cute::make_shape(size<0>(thr_layout), _1{}, _1{}); // + auto blk_coord = cute::make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L) + + // Tile across the block + auto gOp_dq = cute::local_tile(gmem_op_dq, blk_shape, blk_coord); + auto gScale = cute::local_tile(gmem_scale_broadcasted, blk_shape, blk_coord); + auto gZero = cute::local_tile(gmem_zero_broadcasted, blk_shape, blk_coord); + auto gOp_q = cute::local_tile(gmem_op_q, blk_shape, blk_coord); + + auto tOpDq_gOpDq = cute::local_partition(gOp_dq, thr_layout, threadIdx.x); + auto tScale_gScale = cute::local_partition(gScale, thr_layout, threadIdx.x); + auto tZero_gZero = cute::local_partition(gZero, thr_layout, threadIdx.x); + auto tOpQ_gOpQ = cute::local_partition(gOp_q, thr_layout, threadIdx.x); + + // Make a fragment of registers to hold gmem loads + cute::Tensor rmem_op_q = cute::make_fragment_like(tOpQ_gOpQ(_, _, _, 0)); + cute::Tensor rmem_scale = cute::make_fragment_like(tScale_gScale(_, _, _, 0)); + cute::Tensor rmem_zero = cute::make_fragment_like(tZero_gZero(_, _, _, 0)); + cute::Tensor rmem_op_dq = cute::make_fragment_like(tOpDq_gOpDq(_, _, _, 0)); + cute::Tensor rmem_op_scaled = cute::make_fragment_like(rmem_op_dq); + cute::Tensor rmem_zero_buf = cute::make_fragment_like(rmem_zero); + + cute::Tensor pred_id = 
cute::make_identity_tensor(shape(operand_layout)); + auto pred_blk_tile = cute::local_tile(pred_id, blk_shape, blk_coord); + auto pred_thr_partition = cute::local_partition(pred_blk_tile, thr_layout, threadIdx.x); + + const auto num_iters = cute::size<3>(tOpDq_gOpDq); + + for (int ii = 0; ii < num_iters; ++ii) { + const auto thread_offset = cute::get<0>(pred_thr_partition(0, 0, 0, ii)); + if (thread_offset < cute::size<0>(operand_layout)) { + cute::copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q); + cute::copy(tScale_gScale(_, _, _, ii), rmem_scale); + cute::copy(tZero_gZero(_, _, _, ii), rmem_zero); + cute::transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } ); + cute::transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } ); + cute::transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, cute::multiplies{}); + cute::transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, cute::plus{}); + cute::transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } ); + cute::copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii)); + } + } +} + +template < + class QuantizedElement, + class DequantizedElement, + class OperandLayout, + class ElementScale, + class ElementZero, + class ScaleLayout> +static void dequantize(DequantizedElement* dq_buffer, + QuantizedElement const* q_buffer, + OperandLayout const operand_layout, + ElementScale const* scale_buffer, + ElementZero const* zero_buffer, + ScaleLayout const scale_layout, + int const group_size, + cudaStream_t &stream) { + using namespace cute; + + constexpr int tpb = 128; + auto thr_layout = make_layout(make_shape(Int{})); + + const auto num_rows = get<0>(shape(operand_layout)); + const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L] + const auto batches = get<2>(shape(operand_layout)); // [MN, K, L] + const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L] + + if (num_rows != 
size<0>(scale_layout)) { + std::cerr << "Invalid first dimension for scales. Must match first dim for weights." + << " But got shapes " << shape(operand_layout) << " " << shape(scale_layout) + << std::endl; + exit(-1); + } + + const auto scale_stride0 = get<0>(stride(scale_layout)); + const auto scale_stride1 = get<1>(stride(scale_layout)); + const auto scale_stride2 = get<2>(stride(scale_layout)); + + auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches); + auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2); + auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast); + + const auto blocks_x = gemm_k; + const auto blocks_y = batches; + + dim3 blocks(blocks_x, blocks_y, 1); + dequantize_kernel<<>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout); + CUDA_CHECK(cudaStreamSynchronize(stream)); +} + +template +class packed_scale_t { +public: + static_assert(cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v, + "only 8 bit arithmetic types are supported."); + CUTLASS_HOST_DEVICE + explicit packed_scale_t(T val) { + if constexpr (!cute::is_unsigned_v) { + // Only pack negative values. The positive values are generated in flight in the mainloop. 
+ storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f)); + storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val); + } + else { + storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f)); + storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val); + } + } + CUTLASS_HOST_DEVICE + packed_scale_t() = default; + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + CUTLASS_HOST_DEVICE + bool operator==(packed_scale_t const& rhs) const { + return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1]; + } + CUTLASS_HOST_DEVICE + bool operator!=(packed_scale_t const& rhs) const { + return !(*this == rhs); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() + rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() - rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() * rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() / rhs.get()); + } + +private: + using Storage = uint32_t; + using Stage = uint8_t; + + Storage storage[2] {}; + + CUTLASS_HOST_DEVICE + static Storage pack4(T c1, T c2, T c3, T c4) { + Storage result = 0; + result |= (static_cast(reinterpret_cast(c4)) << 24); + result |= (static_cast(reinterpret_cast(c3)) << 16); + result |= (static_cast(reinterpret_cast(c2)) << 8); + result |= static_cast(reinterpret_cast(c1)); + return result; + } + CUTLASS_HOST_DEVICE + T get() const { + auto stage = static_cast(storage[0] >> 8); + #if 
defined(__CUDA_ARCH__) + return reinterpret_cast(stage); + #else + T tmp; + std::memcpy(&tmp, &stage, sizeof(Stage)); + return tmp; + #endif + } + CUTLASS_HOST_DEVICE + T get(int idx) const { + Stage stage; + if (idx < 4) stage = static_cast(storage[0] >> (8 * idx)); + else stage = static_cast(storage[1] >> (8 * idx - 32)); + #if defined(__CUDA_ARCH__) + return reinterpret_cast(stage); + #else + T tmp; + std::memcpy(&tmp, &stage, sizeof(Stage)); + return tmp; + #endif + } +}; + +// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT. +// Here the encodings of positive values and negative values are unified (except for the sign bit). +// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111). +static bool unified_encode_int4b(cutlass::int4b_t const *block_in, cutlass::int4b_t *block_out, const size_t block_size) { + + using StorageType = cutlass::int4b_t::Storage; + constexpr int pack = cute::sizeof_bits_v / 4; + const size_t host_buf_size = block_size / pack; + std::vector host_buf(host_buf_size); + cutlass::device_memory::copy_to_host(host_buf.data(), (StorageType *) block_in, host_buf_size); + + for (auto&& d : host_buf) { + StorageType out = 0; + StorageType mask = 0x0f; + for (int i = 0; i < pack; i++) { + cutlass::int4b_t curr; + curr.storage = (d >> (i * 4)) & 0x0f; + switch (curr) { + case 1: curr.storage = StorageType(0b0111); break; // 2's complement + case 2: curr.storage = StorageType(0b0110); break; // 2's complement + case 3: curr.storage = StorageType(0b0101); break; // 2's complement + case 4: curr.storage = StorageType(0b0100); break; // 2's complement + case 5: curr.storage = StorageType(0b0011); break; // 2's complement + case 6: curr.storage = StorageType(0b0010); break; // 2's complement + case 7: curr.storage = StorageType(0b0001); break; // 2's complement + default: break; + } + out |= (curr.storage << (4 * i)) & mask; + mask <<= 4; + } + d = out; + } + + 
cutlass::device_memory::copy_to_device((StorageType*) block_out, host_buf.data(), host_buf_size); + return true; +} + +template +static bool pack_scale_fp8(ElementScale const *block_in, cutlass::Array *block_out, const size_t block_size) { + std::vector data_in(block_size); + std::vector> data_out(block_size); + + try { + cutlass::device_memory::copy_to_host(data_in.data(), block_in, block_size); + } + catch (cutlass::cuda_exception const& e) { + std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; + return false; + } + + for (size_t i = 0; i < block_size; i++) { + cutlass::packed_scale_t tmp(data_in[i]); + data_out[i] = reinterpret_cast const&>(tmp); + } + + try { + cutlass::device_memory::copy_to_device(block_out, data_out.data(), block_size); + } + catch (cutlass::cuda_exception const& e) { + std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; + return false; + } + return true; +} + +template +struct UnderlyingElement { + using type = T; +}; + +template +struct UnderlyingElement> { + using type = typename T::Element; +}; + +// Given a type of MMA instruction, compute a memory reordering atom that places all values +// owned by each thread in contiguous memory locations. This improves smem load vectorization, +// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order +// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses. +// In addition, we can reorder the values across several MMA instructions to get even wider +// vectorization (AtomLayout parameter) and permute the values within each instruction to get +// more optimal conversion instruction sequences (ValLayout parameter). 
+template , + class ValLayout = cute::Layout> +constexpr auto compute_memory_reordering_atom(AtomLayout atom_layout = {}, ValLayout val_layout = {}) +{ + using namespace cute; + + static_assert(is_static_v, "ValLayout must be static"); + static_assert(is_static_v, "AtomLayout must be static"); + + // 1. Choose an MMA atom to access TV layout and MN shape + // Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary + using MmaAtom = decltype(SM90::GMMA::rs_op_selector>()); + using MmaTraits = MMA_Traits; + auto mk_shape_mma = select<0,2>(typename MmaTraits::Shape_MNK{}); + auto tv_layout_mma = typename MmaTraits::ALayout{}; + static_assert(size<1>(tv_layout_mma) % size(val_layout) == 0, "Value layout must evenly divide the MMA value layout"); + + // 2. Create a single warp's TV layout from that of the whole MMA and invert to get (m,k -> thr,val) + // Note: this assumes A is partitioned between warps along M mode + auto tv_tiler_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma)); + auto mk_shape_warp = shape_div(mk_shape_mma, size(typename MmaTraits::ThrID{}) / Int<32>{}); + auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tv_tiler_warp)); + auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(mk_shape_warp); + + // 3. Repeat the warp layout NumAtoms times along K mode to get wider vectorization + auto mk_layout_mma_trgt = blocked_product(mk_layout_mma_warp, atom_layout); + + // 4. 
Compose with a contiguous layout of values in each thread (required for smem vectorization) + auto val_to_offset = logical_product(val_layout, size<1>(tv_layout_mma) / size(val_layout) * size(atom_layout)); + auto thr_to_offset = make_layout(size<0>(tv_layout_mma_warp)); + auto tv_to_offset = select<1,0>(logical_product(val_to_offset, thr_to_offset)); + auto layout_atom = composition(tv_to_offset, mk_layout_mma_trgt); + + return layout_atom; +} + +template +__global__ void reorder_tensor_kernel( + cute::Tensor S, + cute::Tensor D, + TiledCopy tiled_copy) +{ + using namespace cute; + + using T = typename EngineDst::value_type; + + Tensor gS = local_tile(S, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z)); + Tensor gD = local_tile(D, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z)); + + auto thread_copy = tiled_copy.get_slice(threadIdx.x); + Tensor tS = thread_copy.partition_S(gS); + Tensor tD = thread_copy.partition_D(gD); + + copy(tiled_copy, tS, tD); +} + +template +void reorder_tensor( + cute::Tensor S, + cute::Tensor D) +{ + using namespace cute; + + using T = typename EngineDst::value_type; + static_assert(is_same_v, T>, "Type mismatch"); + + // Construct a value layout that assigns at least 8 bits of contiguous elements in destination tensor to a thread + // This avoids a race condition when writing out subbyte types (e.g. int4b_t). 
+ auto has_major_mode = [](auto s) { + return any_of(flatten(s), [](auto a){ return is_constant<1, decltype(a)>{}; }); + }; + static_assert(has_major_mode(stride<0>(LayoutDst{})) ^ has_major_mode(stride<1>(LayoutDst{})), + "Could not find stride-1 mode in destination layout"); + constexpr int N = shape_div(Int<8>{}, Int>{}); + auto val_layout = conditional_return(LayoutDst{}))>( + make_layout(make_shape(Int{}, Int<1>{}), GenColMajor{}), + make_layout(make_shape(Int<1>{}, Int{}), GenRowMajor{})); + + // Make a tiled copy with a simple row-major thread order and above layout + int constexpr NumThreads = 128; + auto const thr_layout = make_layout(make_shape(Int<1>{}, Int{})); + auto tiled_copy = make_tiled_copy(Copy_Atom{}, thr_layout, val_layout); + + // Assign a group of 16 rows to a threadblock; this matches the shuffle atom size for Hopper + using TileShape = Shape<_16>; + auto tiled_D = group_modes<3,rank_v>(tiled_divide(D, TileShape{})); + dim3 blocks{unsigned(size<1>(tiled_D)), 1u, unsigned(size<3>(tiled_D))}; + + reorder_tensor_kernel<<>>(S, D, tiled_copy); + CUDA_CHECK(cudaDeviceSynchronize()); +} + +// In-place version +template +void reorder_tensor( + T const* src, + LayoutSrc const& layout_src, + T * dst, + LayoutDst const& layout_dst) +{ + using namespace cute; + reorder_tensor(make_tensor(make_gmem_ptr(src), layout_src), + make_tensor(make_gmem_ptr(dst), layout_dst)); +} + +// In-place version +template +void reorder_tensor( + T * data, + LayoutSrc const& layout_src, + LayoutDst const& layout_dst) +{ + using namespace cute; + cutlass::DeviceAllocation temp(size(layout_src)); + reorder_tensor(data, layout_src, temp.get(), layout_dst); + cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast(size(layout_src))); +} + +#undef CUDA_CHECK + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp new file mode 100644 index 0000000000000000000000000000000000000000..811ba152ab7c6e8fafc1cebdbb3726798fd16b3c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp @@ -0,0 +1,570 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Utilities for packing constructing canonical CuTe stride types for 3.x mainloop params. +*/ + +#pragma once + +#include "cute/layout.hpp" +#include "cute/container/array.hpp" // cute::array +#include "cutlass/conv/convolution.h" // cutlass::conv::Operator + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides without batch mode + +template +CUTLASS_HOST_DEVICE +cute::Stride> +make_cute_packed_stride(cute::Stride> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride(cute::Stride, IntT> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides with batch mode + +template +CUTLASS_HOST_DEVICE +cute::Stride, int64_t> +make_cute_packed_stride(cute::Stride, int64_t> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + int batch_count = cute::get<2>(shape_MKL); + if (batch_count > 1) { + cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); + } + else { + cute::get<2>(s_copy) = static_cast(0); + } + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, int64_t> +make_cute_packed_stride(cute::Stride, IntT, int64_t> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + int batch_count = cute::get<2>(shape_MKL); + if (batch_count > 1) { + cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); + } + else { + cute::get<2>(s_copy) = static_cast(0); + } + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides with group mode + +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<0>> +make_cute_packed_stride(cute::Stride, cute::Int<0>> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, StrideIntT, cute::Int<0>> +make_cute_packed_stride(cute::Stride, StrideIntT, cute::Int<0>> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides for convolutions + +// Output cutlass::layout::TensorNDHWC -> rank-3 stride (InT,_1,_0) +// Note: For fprop/dgrad kernel, strides are assumed to be layout right in NZPQK/NDHWC order +// and therefore can be coalesced to just q/w. For wgrad kernel, strides are assumed to be layout +// right in KTRSC order and can be coalesced to just k. +// We enforce this condition here with asserts. +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, cute::Int<0>> s, + cute::array shape_output, + cute::array stride_output, + cutlass::conv::Operator conv_op) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + static_assert(RankT_ >= 3u); + constexpr static int RankT = static_cast(RankT_); + + assert(stride_output[RankT-1] == 1); + cute::for_each(cute::make_seq{}, [&](auto i) { + assert(stride_output[i] == shape_output[i+1] * stride_output[i+1]); + }); + + auto s_copy = s; + cute::get<0>(s_copy) = (conv_op == cutlass::conv::Operator::kWgrad) ? 
+ stride_output[0] : + stride_output[RankT-2]; + return s_copy; +} + +// +// Activation tensor ((w, h, d, n), _1) for fprop kernel +// + +// Activation cutlass::layout::TensorNWC -> rank-2 stride ((W,N),_1) +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>> +make_cute_packed_stride( + cute::Stride, cute::Int<1>> s, + cute::array stride_nwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + assert(stride_nwc[2] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_nwc[1]; + cute::get<0,1>(s_copy) = stride_nwc[0]; + return s_copy; +} + +// Activation cutlass::layout::TensorNHWC -> rank-2 stride ((W,H,N),_1) +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>> +make_cute_packed_stride( + cute::Stride, cute::Int<1>> s, + cute::array stride_nhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + assert(stride_nhwc[3] == 1); + auto s_copy = s; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<0,i>(s_copy) = stride_nhwc[2-i]; + }); + return s_copy; +} + +// Activation cutlass::layout::TensorNDHWC -> rank-2 stride ((W,H,D,N),_1) +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>> +make_cute_packed_stride( + cute::Stride, cute::Int<1>> s, + cute::array stride_ndhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + + assert(stride_ndhwc[4] == 1); + auto s_copy = s; + cute::for_each(cute::make_seq<4>{}, [&](auto i) { + cute::get<0,i>(s_copy) = stride_ndhwc[3-i]; + }); + return s_copy; +} + +// +// Filter tensor (k, (_1, s, r, t)) for fprop kernel +// + +// Filter cutlass::layout::TensorNWC -> rank-2 stride (k, (_1, s)) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>> +make_cute_packed_stride( + cute::Stride, IntT>> s, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ksc[0]; + cute::get<1,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorNHWC -> rank-2 stride (k, (_1, s, r)) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>> s, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<1,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorNDHWC -> rank-2 stride (k, (_1, s, r, t)) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>> s, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} + +// +// Activation tensor (_1, (w, h, d, n)) for wgrad kernel +// +// It is also Filter tensor ((_1), (k, s, r, t)) for dgrad kernel +// + +// Activation cutlass::layout::TensorNWC -> rank-2 stride (_1, (W,N)) in wgrad +// Filter cutlass::layout::TensorNWC -> rank-2 stride ((_1), (k, s)) in dgrad +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Stride> +make_cute_packed_stride( + cute::Stride, cute::Stride> s, + cute::array stride_nwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nwc[2] == 1); + auto s_copy = s; + if (ConvOp == cutlass::conv::Operator::kWgrad) { + cute::get<1,0>(s_copy) = stride_nwc[1]; + cute::get<1,1>(s_copy) = stride_nwc[0]; + } + else if (ConvOp == cutlass::conv::Operator::kDgrad) { + // stride_nwc in dgrad is ksc. + cute::get<1,0>(s_copy) = stride_nwc[0]; + cute::get<1,1>(s_copy) = stride_nwc[1]; + } + return s_copy; +} + +// Activation cutlass::layout::TensorNHWC -> rank-2 stride (_1, (W,H,N)) in wgrad +// Filter cutlass::layout::TensorNHWC -> rank-2 stride ((_1), (k, s, r)) in dgrad +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Stride> +make_cute_packed_stride( + cute::Stride, cute::Stride> s, + cute::array stride_nhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + + assert(stride_nhwc[3] == 1); + auto s_copy = s; + if (ConvOp == cutlass::conv::Operator::kWgrad) { + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,i>(s_copy) = stride_nhwc[2-i]; + }); + } + else if (ConvOp == cutlass::conv::Operator::kDgrad) { + // stride_nhwc in dgrad is krsc. + cute::get<1,0>(s_copy) = stride_nhwc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<1,2-i>(s_copy) = stride_nhwc[i+1]; + }); + } + return s_copy; +} + +// Activation cutlass::layout::TensorNDHWC -> rank-2 stride (_1, (W,H,D,N)) in wgrad +// Filter cutlass::layout::TensorNDHWC -> rank-2 stride ((_1), (k, s, r, t)) in dgrad +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Stride> +make_cute_packed_stride( + cute::Stride, cute::Stride> s, + cute::array stride_ndhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ndhwc[4] == 1); + auto s_copy = s; + if (ConvOp == cutlass::conv::Operator::kWgrad) { + cute::for_each(cute::make_seq<4>{}, [&](auto i) { + cute::get<1,i>(s_copy) = stride_ndhwc[3-i]; + }); + } + else if (ConvOp == cutlass::conv::Operator::kDgrad) { + // stride_ndhwc in dgrad is ktrsc. + cute::get<1,0>(s_copy) = stride_ndhwc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,3-i>(s_copy) = stride_ndhwc[i+1]; + }); + } + return s_copy; +} + +// +// NZPQ tensor (_1, nzpq) for wgrad kernel +// + +// cutlass::layout::TensorNWC -> rank-2 stride (_1, nzpq) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride( + cute::Stride, IntT> s, + cute::array stride_nqk, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + + assert(stride_nqk[2] == 1); + auto s_copy = s; + cute::get<1>(s_copy) = stride_nqk[1]; + return s_copy; +} + +// cutlass::layout::TensorNHWC -> rank-2 stride (_1, nzpq) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride( + cute::Stride, IntT> s, + cute::array stride_npqk, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_npqk[3] == 1); + auto s_copy = s; + cute::get<1>(s_copy) = stride_npqk[2]; + return s_copy; +} + +// cutlass::layout::TensorNDHWC -> rank-2 stride (_1, nzpq) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride( + cute::Stride, IntT> s, + cute::array stride_nzpqk, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nzpqk[4] == 1); + auto s_copy = s; + cute::get<1>(s_copy) = stride_nzpqk[3]; + return s_copy; +} + + + +// +// Wgrad output tensor (k, (_1, s, r, t), _0) +// + +// Filter cutlass::layout::TensorKCS -> rank-3 stride (k, (_1, s), _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT>, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ksc[0]; + cute::get<1,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorKCSR -> rank-3 stride (k, (_1, s, r), _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<1,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorKCSRT -> rank-3 stride (k, (_1, s, r, t), _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} + + +// +// Wgrad output tensor ((_1, s, r, t), k, _0) +// + +// Filter cutlass::layout::TensorCSK -> rank-3 stride ((_1, s), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_ksc[0]; + cute::get<0,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorCSRK -> rank-3 stride ((_1, s, r), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<0,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorCSRTK -> rank-3 stride ((_1, s, r, t), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<0,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c38ad3f710c18e5be1bb7e01dc66d7efcd2646d9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp @@ -0,0 +1,341 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include + +// The computed infinity norm does not include +// any NaN column absolute-value sums. +struct matrix_inf_norm_result { + // Accumulate errors in double, as this is generally + // the highest precision that the examples use. 
+ double inf_norm = 0.0; + bool found_nan = false; +}; + +// In theory, cute::Tensor, T> could be treated as a view type, +// and thus passed by value (as std::span or std::string_view would be). +// However, generic cute::Tensor are more like containers +// and thus are best passed by reference or const reference. +template +matrix_inf_norm_result +matrix_inf_norm(cute::Tensor const& host_matrix) +{ + using error_type = decltype(std::declval().inf_norm); + using element_type = typename EngineType::value_type; + + error_type inf_norm = 0.0; + bool found_nan = false; + + // Computing the infinity norm requires that we be able + // to treat the input as a matrix, with rows and columns. + const int64_t num_rows = cute::size<0>(host_matrix); + const int64_t num_cols = cute::size<1>(host_matrix); + + auto abs_fn = [] (element_type A_ij) { + if constexpr (not std::is_unsigned_v) { + using std::abs; + return abs(A_ij); + } + else { + return A_ij; + } + }; + + for (int64_t i = 0; i < num_rows; ++i) { + error_type row_abs_sum = 0.0; + for(int64_t j = 0; j < num_cols; ++j) { + row_abs_sum += abs_fn(host_matrix(i, j)); + } + if (std::isnan(row_abs_sum)) { + found_nan = true; + } + else { + inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm; + } + } + + return {inf_norm, found_nan}; +} + +// Infinity norm of (X - Y). +template +matrix_inf_norm_result +matrix_diff_inf_norm(cute::Tensor const& X, + cute::Tensor const& Y) +{ + using error_type = decltype(std::declval().inf_norm); + using element_type = typename EngineType::value_type; + + auto abs_fn = [] (element_type A_ij) { + if constexpr (not std::is_unsigned_v) { + using std::abs; + return abs(A_ij); + } + else { + return A_ij; + } + }; + + assert(cute::size<0>(X) == cute::size<0>(Y)); + assert(cute::size<1>(X) == cute::size<1>(Y)); + + // Computing the infinity norm requires that we be able + // to treat the input as a matrix, with rows and columns. 
+ const int64_t num_rows = cute::size<0>(X); + const int64_t num_cols = cute::size<1>(X); + + error_type inf_norm = 0.0; + bool found_nan = false; + + for (int64_t i = 0; i < num_rows; ++i) { + error_type row_abs_sum = 0.0; + for (int64_t j = 0; j < num_cols; ++j) { + row_abs_sum += error_type(abs_fn(element_type(X(i,j)) - + element_type(Y(i,j)))); + } + if (std::isnan(row_abs_sum)) { + found_nan = true; + } + else { + inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm; + } + } + + return {inf_norm, found_nan}; +} + +template +auto +print_matrix_multiply_mollified_relative_error( + char const A_value_type_name[], + cute::Tensor const& A, + char const B_value_type_name[], + cute::Tensor const& B, + char const C_value_type_name[], + cute::Tensor const& C, + cute::Tensor const& C_ref) +{ + const auto [A_norm, A_has_nan] = matrix_inf_norm(A); + const auto [B_norm, B_has_nan] = matrix_inf_norm(B); + const auto [C_norm, C_has_nan] = matrix_inf_norm(C_ref); + const auto [diff_norm, diff_has_nan] = matrix_diff_inf_norm(C, C_ref); + + const auto A_norm_times_B_norm = A_norm * B_norm; + const auto relative_error = A_norm_times_B_norm == 0.0 ? + diff_norm : (diff_norm / A_norm_times_B_norm); + + // For expected error bounds, please refer to the LAPACK Users' Guide, + // in particular https://netlib.org/lapack/lug/node108.html . + // Printing the infinity norm of C is a way to check + // that both the function being tested (C) + // and the reference implementation (C_ref) + // don't just do nothing (or fill with zeros). 
+ using std::cout; + using cute::shape; + cout << "Matrix A: " << shape<0>(A) << "x" << shape<1>(A) << " of " << A_value_type_name << '\n' + << "Matrix B: " << shape<0>(B) << "x" << shape<1>(B) << " of " << B_value_type_name << '\n' + << "Matrix C: " << shape<0>(C) << "x" << shape<1>(C) << " of " << C_value_type_name << '\n' + << std::scientific + << "Infinity norm of A: " << A_norm << '\n' + << "Infinity norm of B: " << B_norm << '\n' + << "Infinity norm of C: " << C_norm << '\n' + << "Infinity norm of (C - C_ref): " << diff_norm << '\n'; + + if(A_norm_times_B_norm == 0.0) { + cout << "Mollified relative error: " << relative_error << '\n'; + } else { + cout << "Relative error: " << relative_error << '\n'; + } + + if (A_has_nan || B_has_nan || C_has_nan || diff_has_nan) { + cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in C? " << (C_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in (C - C_ref)? " << (diff_has_nan ? "yes" : "no") << '\n'; + } + return relative_error; +} + +template +auto +print_matrix_multiply_mollified_relative_error( + const char value_type_name[], + const cute::Tensor& A, + const cute::Tensor& B, + const cute::Tensor& C_computed, + const cute::Tensor& C_expected) +{ + return print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B, + value_type_name, C_computed, C_expected); +} + +// Take a CUTLASS HostTensor (or the like) as input, +// and return a const CuTe Tensor. +// This is useful for use with the above error printing functions. +// This implicitly "transposes" if the layout is RowMajor. +// Note that the HostTensor must be captured by nonconst reference +// in order for X.host_ref().data() to compile. +// (CUTLASS is a bit more container-y than CuTe.) 
+template +auto host_matrix_to_const_cute_tensor(CutlassHostTensorType& X) +{ + // The tensors were created with post-transposed extents. + const auto extents = X.extent(); + const auto shape = cute::Shape{extents[0], extents[1]}; + // Both RowMajor and ColumnMajor only store one stride. + const int LDX = X.stride(0); + const auto strides = [&]() { + using input_layout_type = typename std::decay_t::Layout; + if constexpr (std::is_same_v) { + return cute::Stride{1, LDX}; + } + else { + static_assert(std::is_same_v); + return cute::Stride{LDX, 1}; + } + }(); + const auto layout = cute::make_layout(shape, strides); + auto X_data = X.host_ref().data(); + auto X_data_const = const_cast >(X_data); + return cute::make_tensor(X_data_const, layout); +}; + + +// Returns EXIT_SUCCESS if the 2-norm relative error is exactly zero, else returns EXIT_FAILURE. +// This makes the return value suitable as the return value of main(). +template +int +print_relative_error( + std::size_t n, + T1 const& data, + T2 const& reference, + bool print_verbose = false, + bool print_error = true, + double error_margin = 0.00001) { + using std::abs; using std::sqrt; + + // Use either double or complex for error computation + using value_type = cute::remove_cvref_t; + using error_type = std::conditional_t::value, + cute::complex, + double>; + + if (print_verbose) { + std::cout << "Idx:\t"<< "Val\t" << "RefVal\t" << "RelError" << std::endl; + } + + double eps = 1e-200; + + double tot_error_sq = 0; + double tot_norm_sq = 0; + double tot_ind_rel_err = 0; + double max_ind_rel_err = 0; + double max_diff = 0; + for (std::size_t i = 0; i < n; ++i) { + error_type val = data[i]; + error_type ref = reference[i]; + + double aref = abs(ref); + double diff = abs(ref - val); + double rel_error = diff / (aref + eps); + + // Individual relative error + tot_ind_rel_err += rel_error; + + // Maximum relative error + max_ind_rel_err = std::max(max_ind_rel_err, rel_error); + + // Maximum delta in value error + max_diff 
= std::max(max_diff, diff); + + // Total relative error + tot_error_sq += diff * diff; + tot_norm_sq += aref * aref; + + if (print_verbose) { + std::cout << i << ":\t" << val << "\t" << ref << "\t" << rel_error << std::endl; + } + } + + double ave_rel_err = tot_ind_rel_err / double(n); + if (print_error) { + printf("Average relative error: %.3e\n", ave_rel_err); + } + + if (print_error) { + printf("Maximum relative error: %.3e\n", max_ind_rel_err); + } + + if (print_error) { + printf("Maximum difference : %.3e\n", max_diff); + } + + double tot_rel_err = sqrt(tot_error_sq/(tot_norm_sq+eps)); + if (print_error) { + printf("Vector relative error: %.3e\n", tot_rel_err); + } + + printf("Vector reference norm: %.3e\n", sqrt(tot_norm_sq)); + + return (tot_rel_err <= error_margin) ? EXIT_SUCCESS : EXIT_FAILURE; +} + +// Overload for cute::Tensor<> +template +int +print_relative_error( + cute::Tensor data, + cute::Tensor reference, + bool print_verbose = false, + bool print_error = true, + double error_margin = 0.00001) { + assert(size(data) == size(reference)); + return print_relative_error(static_cast(size(data)), + data, reference, + print_verbose, print_error, error_margin); +} diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h new file mode 100644 index 0000000000000000000000000000000000000000..8167c91bf2330d160a78ba210449357b395964ca --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h @@ -0,0 +1,135 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. 
+*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" + +namespace cutlass { +namespace reference { +namespace detail { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template function to compute an inner product. +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate with a + // host-only type +template +CUTLASS_HOST_DEVICE +Ctype inner_product(Atype a, Btype b, Ctype c) { + return Ctype(a) * Ctype(b) + c; +} + +/// Specialization for matrix multiplication with binary operands +template <> +CUTLASS_HOST_DEVICE +int inner_product, Array, int>( + Array a, + Array b, + int c) { + + int accum = 0; + for (int bit = 0; bit < 32; bit++) { + accum += a[bit] ^ b[bit]; + } + return accum + c; +} + +/* +/// Specialization for matrix multiplication with signed 4-bit integer operands +template <> +CUTLASS_HOST_DEVICE +int inner_product, Array, int>( + Array a, + Array b, + int c) { + + int accum = 0; + for (int k = 0; k < 8; k++) { + accum += a[k] * b[k]; + } + return accum + c; +} + +/// Specialization for matrix multiplication with unsigned 4-bit integer operands +template <> +CUTLASS_HOST_DEVICE +int inner_product, Array, int>( + Array a, + Array b, + int c) { + + int accum = 0; + for (int k = 0; k < 8; k++) { + accum += a[k] * b[k]; + } + return accum + c; +} +*/ + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Cast { + // Default behavior: convert to the destination type +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type + CUTLASS_HOST_DEVICE + static DstType apply(SrcType src) { return static_cast(src); }; +}; + +template <> +struct Cast { + CUTLASS_HOST_DEVICE + static int8_t apply(float src) { + // Clamp to the range of signed 8-bit integers. 
+ return static_cast(fmaxf(-128.f, fminf(127.f, src))); + }; +}; + +template <> +struct Cast { + CUTLASS_HOST_DEVICE + static uint8_t apply(float src) { + // Clamp to the range of signed 8-bit integers. + return static_cast(fmaxf(0.f, fminf(255.f, src))); + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail +} // namespace reference +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h new file mode 100644 index 0000000000000000000000000000000000000000..652d622586cb202ecfe69ac892978b649b5d1be7 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/coord.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct LinearToCoordinateHelper { + + CUTLASS_HOST_DEVICE + void operator()(Coord &coord, int64_t idx, Coord const &extent) const { + + int64_t prod = 1; + + CUTLASS_PRAGMA_UNROLL + for (int i = Rank - Index; i < Rank; ++i) { + prod *= int64_t(extent[i]); + } + + coord[Rank - Index - 1] = int(idx / prod); + + int64_t residual = idx % prod; + LinearToCoordinateHelper()(coord, residual, extent); + } +}; + +template +struct LinearToCoordinateHelper { + + CUTLASS_HOST_DEVICE + void operator()(Coord &coord, int64_t idx, Coord const &) const { + coord[Rank - 1] = int(idx); + } +}; + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct LinearToCoordinate { + + CUTLASS_HOST_DEVICE + void operator()(Coord &coord, int64_t idx, Coord const &extent) const { + LinearToCoordinateHelper()(coord, idx, extent); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..7c6f803c47f5c407cf058d40bc8274a448a36dc4 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h @@ -0,0 +1,1549 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Reference implementation for convolution in device-side code. 
+*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +namespace cutlass { +namespace reference { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Conv2d device reference kernel +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2d Fprop kernel - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv2dFprop( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int k_start = blockIdx.y * kCtaShapeN * 
kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_p[kThreadM]; + int thread_q[kThreadM]; + + // Compute N, P, Q coordinates for each row of a thread's tile + int64_t PQ = int64_t(problem_size.P) * problem_size.Q; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t npq = npq_start + m; + + thread_n[m] = int(npq / PQ); + + int64_t residual = npq % PQ; + thread_p[m] = int(residual / problem_size.Q); + thread_q[m] = int(residual % problem_size.Q); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + int c_per_group = problem_size.C / problem_size.groups; + int k_per_group = problem_size.K / problem_size.groups; + + // Compute convolution + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int C = 0; C < problem_size.C; ++C) { + + // Get group id of currnet channel + int c_group_idx = C / c_per_group; + + // Load from activations tensor + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (thread_n[m] < problem_size.N && h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) { + element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], h, w, C})); + } + else { + element_A[m] = ElementAccumulator(); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + int k_group_idx = thread_k / k_per_group; + + if (thread_k < problem_size.K && 
k_group_idx == c_group_idx) { + element_B[n] = ElementAccumulator(tensor_w.at({thread_k, R, S, C % c_per_group})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + } + } + } + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + if (thread_k < problem_size.K) { + + ElementCompute c_ref = ElementCompute(); + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})); + } + + tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +// Conv3d Fprop kernel - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv3dFprop( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + 
ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t nzpq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_z[kThreadM]; + int thread_p[kThreadM]; + int thread_q[kThreadM]; + + // Compute N, Z, P, Q coordinates for each row of a thread's tile + int64_t PQ = int64_t(problem_size.P) * problem_size.Q; + int64_t ZPQ = PQ * problem_size.Z; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t nzpq = nzpq_start + m; + + thread_n[m] = int(nzpq / ZPQ); + + int64_t residual = nzpq % ZPQ; + thread_z[m] = int(residual / PQ); + + residual = residual % PQ; + thread_p[m] = int(residual / problem_size.Q); + thread_q[m] = int(residual % problem_size.Q); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int T = 0; T < problem_size.T; ++T) { + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int C = 0; C < problem_size.C; ++C) { + + // Load from activations tensor + int filter_t = T; + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - T; + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int d = thread_z[m] * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; + int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (thread_n[m] < 
problem_size.N && + d >= 0 && d < problem_size.D && + h >= 0 && h < problem_size.H && + w >= 0 && w < problem_size.W) { + + element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], d, h, w, C})); + } + else { + element_A[m] = ElementAccumulator(); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + + if (thread_k < problem_size.K) { + element_B[n] = ElementAccumulator(tensor_w.at({thread_k, T, R, S, C})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + + } // for (C) + } // for (S) + } // for (R) + } // for (T) + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + if (thread_n[m] < problem_size.N && + thread_z[m] < problem_size.Z && + thread_p[m] < problem_size.P && + thread_q[m] < problem_size.Q) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + if (thread_k < problem_size.K) { + + ElementCompute c_ref = ElementCompute(); + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k})); + } + + tensor_y_out.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } // for (n) + + } + } // for (m) +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2d dgrad kernel - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = 
NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv2dDgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t nhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_h[kThreadM]; + int thread_w[kThreadM]; + + // Compute N, H, W coordinates for each row of a thread's tile + int64_t HW = int64_t(problem_size.H) * problem_size.W; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t nhw = nhw_start + m; + + thread_n[m] = int(nhw / HW); + + int64_t residual = nhw % HW; + thread_h[m] = int(residual / problem_size.W); + thread_w[m] = int(residual % problem_size.W); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int K = 0; K < problem_size.K; ++K) { + + // Load from activations tensor + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + 
CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w; + + element_A[m] = ElementAccumulator(); + + if (p >= 0 && !(p % problem_size.stride_h) && q >= 0 && !(q % problem_size.stride_w)) { + + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; + + if (thread_n[m] < problem_size.N && p < problem_size.P && q < problem_size.Q) { + element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], p, q, K})); + } + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + + if (thread_c < problem_size.C) { + element_B[n] = ElementAccumulator(tensor_w.at({K, R, S, thread_c})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + } + } + } + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + if (thread_n[m] < problem_size.N && thread_h[m] < problem_size.H && thread_w[m] < problem_size.W) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + if (thread_c < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_h[m], thread_w[m], thread_c})); + } + + tensor_dx_out.at({thread_n[m], thread_h[m], thread_w[m], thread_c}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +// Conv3d dgrad kernel - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename 
ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv3dDgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t ndhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_d[kThreadM]; + int thread_h[kThreadM]; + int thread_w[kThreadM]; + + // Compute N, H, W coordinates for each row of a thread's tile + int64_t HW = int64_t(problem_size.H) * problem_size.W; + int64_t DHW = HW * problem_size.D; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t ndhw = ndhw_start + m; + + thread_n[m] = int(ndhw / DHW); + + int64_t residual = ndhw % DHW; + thread_d[m] = int(residual / HW); + + residual = residual % HW; + thread_h[m] = int(residual / problem_size.W); + thread_w[m] = int(residual % problem_size.W); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int T = 0; T < problem_size.T; ++T) { + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int K = 0; K 
< problem_size.K; ++K) { + + // Load from activations tensor + int filter_t = T; + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - T; + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int z = thread_d[m] + problem_size.pad_d - filter_t * problem_size.dilation_d; + int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w; + + element_A[m] = ElementAccumulator(); + + if (z >= 0 && !(z % problem_size.stride_d) && + p >= 0 && !(p % problem_size.stride_h) && + q >= 0 && !(q % problem_size.stride_w)) { + + z = z / problem_size.stride_d; + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; + + if (thread_n[m] < problem_size.N && z < problem_size.Z && p < problem_size.P && q < problem_size.Q) { + element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], z, p, q, K})); + } + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + + if (thread_c < problem_size.C) { + element_B[n] = ElementAccumulator(tensor_w.at({K, T, R, S, thread_c})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + + } // for (C) + } // for (S) + } // for (R) + } // for (T) + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + if (thread_n[m] < problem_size.N && + thread_d[m] < problem_size.D && + thread_h[m] < problem_size.H && + thread_w[m] < problem_size.W) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int 
thread_c = c_start + n; + if (thread_c < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c})); + } + + tensor_dx_out.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2d wgrad kernel - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 8, // shape of a threadblock in units of threads + int kCtaShapeN = 16 // shape of a threadblock in units of threads +> +__global__ void Conv2dWgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int64_t rsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_r[kThreadN]; + int thread_s[kThreadN]; + int thread_c[kThreadN]; + + // Compute R, S, C coordinates for each row of a thread's tile + int64_t SC = int64_t(problem_size.S) * problem_size.C; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) 
{ + + int64_t rsc = rsc_start + n; + int64_t residual = rsc % SC; + + thread_r[n] = int(rsc / SC); + thread_s[n] = int(residual / problem_size.C); + thread_c[n] = int(residual % problem_size.C); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int N = 0; N < problem_size.N; ++N) { + for (int P = 0; P < problem_size.P; ++P) { + for (int Q = 0; Q < problem_size.Q; ++Q) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + element_A[m] = ElementAccumulator(); + + if (thread_k < problem_size.K) { + element_A[m] = ElementAccumulator(tensor_dy.at({N, P, Q, thread_k})); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + // Load from activations tensor + int filter_r = thread_r[n]; + int filter_s = thread_s[n]; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - filter_r; + filter_s = problem_size.S - 1 - filter_s; + } + + int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + element_B[n] = ElementAccumulator(); + + if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W && thread_c[n] < problem_size.C) { + element_B[n] = ElementAccumulator(tensor_x.at({N, h, w, thread_c[n]})); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + } + } + } + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + if (thread_k < problem_size.K) { + + 
CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + if (thread_r[n] < problem_size.R && thread_s[n] < problem_size.S && thread_c[n] < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_r[n], thread_s[n], thread_c[n]})); + } + + tensor_dw_out.at({thread_k, thread_r[n], thread_s[n], thread_c[n]}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +// Conv3d wgrad kernel - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 8, // shape of a threadblock in units of threads + int kCtaShapeN = 16 // shape of a threadblock in units of threads +> +__global__ void Conv3dWgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int64_t trsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_t[kThreadN]; + int thread_r[kThreadN]; + int thread_s[kThreadN]; + int thread_c[kThreadN]; + + // Compute R, S, C coordinates for each row of a thread's tile + int64_t SC = int64_t(problem_size.S) * problem_size.C; + int64_t RSC = SC * problem_size.R; + + 
CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + int64_t trsc = trsc_start + n; + + thread_t[n] = int(trsc / RSC); + + int64_t residual = trsc % RSC; + thread_r[n] = int(residual / SC); + + residual = residual % SC; + thread_s[n] = int(residual / problem_size.C); + thread_c[n] = int(residual % problem_size.C); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int N = 0; N < problem_size.N; ++N) { + for (int Z = 0; Z < problem_size.Z; ++Z) { + for (int P = 0; P < problem_size.P; ++P) { + for (int Q = 0; Q < problem_size.Q; ++Q) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + element_A[m] = ElementAccumulator(); + + if (thread_k < problem_size.K) { + element_A[m] = ElementAccumulator(tensor_dy.at({N, Z, P, Q, thread_k})); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + // Load from activations tensor + int filter_t = thread_t[n]; + int filter_r = thread_r[n]; + int filter_s = thread_s[n]; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - filter_t; + filter_r = problem_size.R - 1 - filter_r; + filter_s = problem_size.S - 1 - filter_s; + } + + int d = Z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; + int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + element_B[n] = ElementAccumulator(); + + if (d >= 0 && d < problem_size.D && + h >= 0 && h < problem_size.H && + w >= 0 && w < problem_size.W && + thread_c[n] < problem_size.C) { + + element_B[n] = ElementAccumulator(tensor_x.at({N, d, h, w, thread_c[n]})); + } + } + + // Accumulate matrix 
product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + + } // for (Q) + } // for (P) + } // for (Z) + } // for (N) + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + if (thread_k < problem_size.K) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + if (thread_t[n] < problem_size.T && + thread_r[n] < problem_size.R && + thread_s[n] < problem_size.S && + thread_c[n] < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]})); + } + + tensor_dw_out.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Conv2d Fprop dispatcher - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2dFprop( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension + int const 
kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q; + int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv2dFprop< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_x, + tensor_w, + tensor_y_in, + tensor_y_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv3d Fprop dispatcher - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3dFprop( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t 
nzpq = int64_t(problem_size.N) * problem_size.Z * problem_size.P * problem_size.Q; + int64_t blocks_m = (nzpq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv3dFprop< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_x, + tensor_w, + tensor_y_in, + tensor_y_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv2d Dgrad dispatcher - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2dDgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t nhw = int64_t(problem_size.N) * problem_size.H * problem_size.W; + int64_t blocks_m = (nhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 
grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv2dDgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_w, + tensor_dx_in, + tensor_dx_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv3d Dgrad dispatcher - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3dDgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t ndhw = int64_t(problem_size.N) * problem_size.D * problem_size.H * problem_size.W; + int64_t blocks_m = (ndhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv3dDgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + 
ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_w, + tensor_dx_in, + tensor_dx_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv2d Wgrad dispatcher - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2dWgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 8; // shape of a threadblock in units of threads + int const kCtaShapeN = 16; // shape of a threadblock in units of threads + + int64_t rsc = int64_t(problem_size.R) * problem_size.S * problem_size.C; + int64_t blocks_n = (rsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n)); + + kernel::Conv2dWgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_x, + tensor_dw_in, + 
tensor_dw_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv3d Wgrad dispatcher - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3dWgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 8; // shape of a threadblock in units of threads + int const kCtaShapeN = 16; // shape of a threadblock in units of threads + + int64_t trsc = int64_t(problem_size.T) * problem_size.R * problem_size.S * problem_size.C; + int64_t blocks_n = (trsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n)); + + kernel::Conv3dWgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_x, + tensor_dw_in, + tensor_dw_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + 
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad.
+///
+/// Selects the 2D dispatcher that matches `convolutional_operator` and forwards the problem
+/// size, the four tensor references, alpha, beta and the stream to it unchanged.
+///
+/// Returns the Status produced by the selected dispatcher, or Status::kErrorNotSupported when
+/// `convolutional_operator` is none of kFprop, kDgrad or kWgrad.
+//
+// NOTE(review): in this copy `TensorRef`, `NumericConverter` and `multiply_add` appear without
+// template argument lists; they look to have been lost in extraction - confirm against the
+// upstream CUTLASS header before relying on these declarations.
+template <
+  typename ElementA,
+  typename LayoutA,
+  typename ElementB,
+  typename LayoutB,
+  typename ElementC,
+  typename LayoutC,
+  typename ElementCompute,
+  typename ElementAccumulator = ElementCompute,
+  typename ConvertOp = NumericConverter,
+  typename InnerProductOp = multiply_add
+>
+Status Conv2d(
+  conv::Operator convolutional_operator,
+  conv::Conv2dProblemSize problem_size,
+  TensorRef tensor_A,
+  TensorRef tensor_B,
+  TensorRef tensor_C,
+  TensorRef tensor_D,
+  ElementCompute alpha,
+  ElementCompute beta,
+  cudaStream_t stream = nullptr) {
+
+  // Every handled case returns directly, so no `break` is needed after it (same shape as the
+  // Conv3d() dispatcher).
+  switch (convolutional_operator) {
+  case conv::Operator::kFprop:
+    return Conv2dFprop<
+      ElementA, LayoutA,
+      ElementB, LayoutB,
+      ElementC, LayoutC,
+      ElementCompute,
+      ElementAccumulator,
+      ConvertOp, InnerProductOp
+    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
+
+  case conv::Operator::kDgrad:
+    return Conv2dDgrad<
+      ElementA, LayoutA,
+      ElementB, LayoutB,
+      ElementC, LayoutC,
+      ElementCompute,
+      ElementAccumulator,
+      ConvertOp, InnerProductOp
+    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
+
+  case conv::Operator::kWgrad:
+    return Conv2dWgrad<
+      ElementA, LayoutA,
+      ElementB, LayoutB,
+      ElementC, LayoutC,
+      ElementCompute,
+      ElementAccumulator,
+      ConvertOp, InnerProductOp
+    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream);
+
+  default: break;
+  }
+
+  // Reached only for operators that have no case above.
+  return Status::kErrorNotSupported;
+}
+
+/// Generic 3D convolution targeting Conv3dFprop, Conv3dDgrad, and Conv3dWgrad.
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3d( + conv::Operator convolutional_operator, + conv::Conv3dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + return Conv3dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + + case conv::Operator::kDgrad: + return Conv3dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + + case conv::Operator::kWgrad: + return Conv3dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + + default: break; + } + + return Status::kErrorNotSupported; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..7d575d522c1dd87d51f9bc58d09786393c5cfea3 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h @@ -0,0 +1,385 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in device-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/util/reference/device/kernel/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. 
+template <
+  typename ElementA,
+  typename LayoutA,
+  typename ElementB,
+  typename LayoutB,
+  typename ElementC,
+  typename LayoutC,
+  typename ScalarType,
+  typename AccumulatorType,
+  typename InnerProductOp = multiply_add,
+  typename ConvertOp = NumericConverter
+>
+void compute_gemm(
+  gemm::GemmCoord problem_size,
+  ScalarType alpha,
+  TensorRef tensor_a,
+  TensorRef tensor_b,
+  ScalarType beta,
+  TensorRef tensor_c,
+  TensorRef tensor_d,
+  AccumulatorType initial_accum) {
+
+  static_assert(
+    LayoutA::kRank == 2 &&
+    LayoutB::kRank == 2 &&
+    LayoutC::kRank == 2, "Tensors must be of rank 2");
+
+  // Each thread is handed a 4x4 output tile. The blocking is there to help the reference
+  // implementation's performance at a small cost in complexity; it is NOT expected to
+  // approach peak performance.
+  using OutputTile = MatrixShape<4, 4>;
+
+  dim3 block(16, 8);
+
+  // Extent of the output covered by one threadblock along each grid dimension.
+  auto const rows_per_block = block.x * OutputTile::kRow;
+  auto const columns_per_block = block.y * OutputTile::kColumn;
+
+  // Round up so that partially covered edges still get a threadblock.
+  dim3 grid(
+    (problem_size.m() + rows_per_block - 1) / rows_per_block,
+    (problem_size.n() + columns_per_block - 1) / columns_per_block
+  );
+
+  // Launch the reference GEMM kernel with the default launch configuration extras
+  // (no explicit shared-memory size or stream argument is given).
+  kernel::Gemm<
+    TensorRef,
+    TensorRef,
+    TensorRef,
+    ScalarType,
+    AccumulatorType,
+    OutputTile,
+    InnerProductOp,
+    ConvertOp
+  ><<< grid, block >>>(
+    problem_size,
+    alpha,
+    tensor_a,
+    tensor_b,
+    beta,
+    tensor_c,
+    tensor_d,
+    initial_accum
+  );
+}
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
+/// objects.
+///
+/// This assumes the accumulator type is the same type as the scalars.
+template <
+  typename ElementA,
+  typename LayoutA,
+  typename ElementB,
+  typename LayoutB,
+  typename ElementC,
+  typename LayoutC,
+  typename ScalarType,
+  typename AccumulatorType,
+  typename InnerProductOp = multiply_add,
+  typename ConvertOp = NumericConverter
+>
+void compute_gemm(
+  gemm::GemmCoord problem_size,
+  ScalarType alpha,
+  TensorRef tensor_a,
+  TensorRef tensor_b,
+  ScalarType beta,
+  TensorRef tensor_c,
+  AccumulatorType initial_accum) {
+
+  // Forward to the overload that takes separate tensor_c and tensor_d arguments, passing
+  // tensor_c for both of them.
+  compute_gemm(
+    problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c,
+    initial_accum);
+}
+
+// Functor form of the reference GEMM. Only declared here; the definitions are the partial
+// specializations that follow. InnerProductOp defaults to cutlass::arch::OpMultiplyAdd.
+//
+// NOTE(review): the specializations below read `template` immediately followed by
+// `struct Gemm {`, and their calls read `compute_gemm>(`; the template parameter and argument
+// lists look to have been lost from this copy - confirm against the upstream CUTLASS header.
+template <
+  typename ElementA,
+  typename LayoutA,
+  typename ElementB,
+  typename LayoutB,
+  typename ElementC,
+  typename LayoutC,
+  typename ScalarType,
+  typename AccumulatorType,
+  typename InnerProductOp = cutlass::arch::OpMultiplyAdd
+>
+struct Gemm;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Partial specialization for multiply-add
+template
+struct Gemm {
+
+  // Variant without tensor_d: checks that all three layouts are rank 2, then calls
+  // compute_gemm with tensor_c as the only C/D tensor argument. initial_accum defaults to
+  // AccumulatorType(0).
+  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
+                  TensorRef tensor_a,
+                  TensorRef tensor_b, ScalarType beta,
+                  TensorRef tensor_c,
+                  AccumulatorType initial_accum = AccumulatorType(0)) {
+
+    static_assert(
+        LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
+        "Tensors must be of rank 2");
+
+    compute_gemm>(
+        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum);
+  }
+
+  // Variant with a separate tensor_d: same rank check, then calls compute_gemm with both
+  // tensor_c and tensor_d. initial_accum defaults to AccumulatorType(0).
+  void operator()(gemm::GemmCoord problem_size, ScalarType alpha,
+                  TensorRef tensor_a,
+                  TensorRef tensor_b, ScalarType beta,
+                  TensorRef tensor_c,
+                  TensorRef tensor_d,
+                  AccumulatorType initial_accum = AccumulatorType(0)) {
+    static_assert(
+        LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2,
+        "Tensors must be of rank 2");
+
+    compute_gemm>(
+        problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum);
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Partial specialization for multiply-add-saturate +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for XOR-popc +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, 
tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Batched GEMM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a batch of GEMMs over a set of matrices of common dimension. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. +// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType, + typename InnerProductOp, + typename ConvertOp +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c, + AccumulatorType initial_accum) { + + static_assert( + TensorRefCollectionA::kRank == 2 && + TensorRefCollectionB::kRank == 2 && + TensorRefCollectionC::kRank == 2, "Tensors must be of rank 2"); + + // Blocking structure potentially improves performance of reference implementation + // with a minor increase in complexity. + // + // Note, this reference implementation is NOT expected to approach peak performance. 
+ using OutputTile = MatrixShape<4, 4>; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), + (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn), + batch_count + ); + + // Launch a GEMM kernel + kernel::BatchedGemm< + TensorRefCollectionA, + TensorRefCollectionB, + TensorRefCollectionC, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + tensor_b, + beta, + tensor_c, + initial_accum + ); +} + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. +// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c) { + + BatchedGemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..bddf596214da62a7aa3177f758db3710dc1d2516 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h @@ -0,0 
+1,350 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in device-side code. 
+*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kMblock = 4, + int kNblock = 4 +> +__global__ void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int 
col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + int batch_idx = blockIdx.z; + + tensor_a.add_pointer_offset(batch_idx * batch_stride_A); + tensor_b.add_pointer_offset(batch_idx * batch_stride_B); + tensor_c.add_pointer_offset(batch_idx * batch_stride_C); + tensor_d.add_pointer_offset(batch_idx * batch_stride_D); + + for (; batch_idx < batch_count; batch_idx += gridDim.z) { + + // Compute matrix product using blocks + ComputeType accum[kMblock][kNblock]; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b = tensor_b.at(MatrixCoord(k_block, col)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_kj = ComputeType(b); + + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + + if (transform_b == ComplexTransform::kConjugate) { + b_kj = conj(b_kj); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + + tensor_a.add_pointer_offset(batch_stride_A * gridDim.z); + tensor_b.add_pointer_offset(batch_stride_B * gridDim.z); + tensor_c.add_pointer_offset(batch_stride_C * gridDim.z); + tensor_d.add_pointer_offset(batch_stride_D * gridDim.z); + + } // for (batch_idx) +} + +} // namespace kernel + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const kMblock = 4; + int const kNblock = 4; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + batch_count % std::numeric_limits::max() + ); + + if (grid.y <= std::numeric_limits::max()) { + kernel::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ScalarType, + ComputeType, + ElementD, + ConvertOp, + InnerProductOp, + kMblock, + kNblock + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + 
initial_accum, + batch_count, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + ); + } else { + // Using bigger thread tile size + int const kBigMblock = 4; + int const kBigNblock = 16; + + dim3 Bigblock(16, 8); + dim3 Biggrid( + (problem_size.m() + block.x * kBigMblock - 1) / (block.x * kBigMblock), + (problem_size.n() + block.y * kBigNblock - 1) / (block.y * kBigNblock), + batch_count % std::numeric_limits::max() + ); + + kernel::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ScalarType, + ComputeType, + ElementD, + ConvertOp, + InnerProductOp, + kBigMblock, + kBigNblock + ><<< Biggrid, Bigblock >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum, + batch_count, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + ); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ElementD = ElementC +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d) { + + GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..48819cf6eaa565b3ec41dbbf78ae244666fd8a65 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h @@ -0,0 +1,311 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in device code. 
+*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_ref_planar_complex.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static int const kGemmPlanarComplexBlockSize = 4; + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +__global__ void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + int const kMblock = kGemmPlanarComplexBlockSize; + int const kNblock = kGemmPlanarComplexBlockSize; + + using ComplexA = typename TensorRefPlanarComplex::ComplexElement; + using ComplexB = typename TensorRefPlanarComplex::ComplexElement; + using ComplexC = typename TensorRefPlanarComplex::ComplexElement; + + // Note: batch is ignored. 
+ int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + complex accum[kMblock][kNblock]; + + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + accum[i][j] = initial_accum; + } + } + + CUTLASS_PRAGMA_NO_UNROLL + for (int k_block = 0; k_block < K; ++k_block) { + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + + ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); + ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); + + complex a = complex{ + ComputeType(a_ik.real()), + ComputeType(a_ik.imag()) + }; + + complex b = complex{ + ComputeType(b_kj.real()), + ComputeType(b_kj.imag()) + }; + + if (transform_a == ComplexTransform::kConjugate) { + a = conj(a); + } + + if (transform_b == ComplexTransform::kConjugate) { + b = conj(b); + } + + accum[i][j] = inner_product_op(a, b, accum[i][j]); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + complex acc{ + ScalarType(accum[i][j].real()), + ScalarType(accum[i][j].imag()) + }; + + ComplexC c_ij = ComplexC(); + + if (beta.real() != ScalarType() || beta.imag() != ScalarType()) { + c_ij = tensor_c.at(coord); + } + + complex src{ + ScalarType(c_ij.real()), + ScalarType(c_ij.imag()) + }; + + complex result = alpha * acc + beta * src; + + ComplexC d_ij; + + d_ij.real() = convert_op(result.real()); + d_ij.imag() = 
convert_op(result.imag()); + + tensor_d.at(coord) = d_ij; + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const kMblock = kernel::kGemmPlanarComplexBlockSize; + int const kNblock = kernel::kGemmPlanarComplexBlockSize; + + dim3 block(16, 8); + + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + 1); + + kernel::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ScalarType, + ComputeType, + ConvertOp, + InnerProductOp + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + 
tensor_c, + tensor_d, + initial_accum + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d) { + + GemmPlanarComplex( + problem_size, + alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, + tensor_c, + tensor_d, + complex()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp new file mode 100644 index 0000000000000000000000000000000000000000..497a257d170c411d891942f62fa2c960453d03d5 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp @@ -0,0 +1,146 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief GETT device reference code +*/ +#pragma once + +#include + +namespace cutlass::reference::device { + +template < + class ATensor, + class BTensor, + class CTensor, + class DTensor, + class ElementAccumulator, + class ElementEpilogue> +__global__ static +void +gett_kernel( + DTensor D, + ATensor const A, + BTensor const B, + CTensor const C, + ElementEpilogue alpha, ElementEpilogue beta, + ElementAccumulator acc_init) +{ + using namespace cute; + + static_assert(DTensor::rank == 3, "(M,N,L)"); + static_assert(ATensor::rank == 3, "(M,K,L)"); + static_assert(BTensor::rank == 3, "(N,K,L)"); + static_assert(CTensor::rank == 3, "(M,N,L)"); + + assert(size<0>(A) == size<0>(D)); // M + assert(size<0>(C) == size<0>(D)); // M + assert(size<0>(B) == size<1>(D)); // N + assert(size<1>(C) == size<1>(D)); // N + assert(size<1>(A) == size<1>(B)); // K + assert(size<2>(A) == size<2>(D)); // L + assert(size<2>(B) == size<2>(D)); // L + assert(size<2>(C) == size<2>(D)); // L + + NumericConverter a_converter; + NumericConverter b_converter; + NumericConverter acc_converter; + NumericConverter source_converter; + NumericConverter output_converter; + + // Thread id to each element of D + for (int tid = threadIdx.x + blockDim.x * blockIdx.x; + tid < size(D); + tid += blockDim.x * gridDim.x) { + // (m,n,l) coordinate + auto mnl_coord = idx2crd(tid, product_each(shape(D))); + auto m = get<0>(mnl_coord); + auto n = get<1>(mnl_coord); + auto l = get<2>(mnl_coord); + + auto A_ml = A(m,_,l); + auto B_nl = B(n,_,l); + + ElementAccumulator accum = ElementAccumulator(0); + for (int k = 0; k < size<1>(A); ++k) { + ElementAccumulator a = a_converter(A_ml(k)); + ElementAccumulator b = b_converter(B_nl(k)); + accum += a * b; + } + + ElementEpilogue scaled_output = (alpha * acc_converter(accum)) + (beta * source_converter(C(m,n,l))); + D(m,n,l) = output_converter(scaled_output); + } +} + +// Most general version +template < + class ProblemShapeMNKL, + class ElementA, + class StrideA, + 
class ElementB, + class StrideB, + class ElementAccumulator, + class ElementC, + class StrideC, + class ElementD, + class StrideD, + class ElementEpilogue> +void +gett( + ProblemShapeMNKL problem_shape_mnkl, + ElementA const* ptr_A, StrideA stride_a_mkl, + ElementB const* ptr_B, StrideB stride_b_nkl, + ElementAccumulator _, + ElementC const* ptr_C, StrideC stride_c_mnl, + ElementD * ptr_D, StrideD stride_d_mnl, + ElementEpilogue alpha, ElementEpilogue beta, + cudaStream_t stream = 0) { + using namespace cute; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4); + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto K = get<2>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + // Represent the full tensors + auto A = make_tensor(make_gmem_ptr(ptr_A), make_shape(M,K,L), stride_a_mkl); // (M,K,L) + auto B = make_tensor(make_gmem_ptr(ptr_B), make_shape(N,K,L), stride_b_nkl); // (N,K,L) + auto C = make_tensor(make_gmem_ptr(ptr_C), make_shape(M,N,L), stride_c_mnl); // (M,N,L) + auto D = make_tensor(make_gmem_ptr(ptr_D), make_shape(M,N,L), stride_d_mnl); // (M,N,L) + + dim3 dimBlock(256); + dim3 dimGrid(240); + gett_kernel<<< dimGrid, dimBlock, 0, stream >>>(D, A, B, C, alpha, beta, ElementAccumulator(0)); +} + +} // namespace cutlass::reference::device diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..6e131126a336420a2b0e843e3ead3d89fce637fa --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h @@ -0,0 +1,162 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 
NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. 
+*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/util/reference/device/thread/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename TensorRefA, + typename TensorRefB, + typename TensorRefC, + typename ScalarType, + typename AccumulatorType, + typename OutputTile, + typename InnerProductOp, + typename ConvertOp +> +__global__ void Gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRefA tensor_a, + TensorRefB tensor_b, + ScalarType beta, + TensorRefC tensor_c, + TensorRefC tensor_d, + AccumulatorType initial_accum) { + + // Map each thread to a unique tile of the output matrix + MatrixCoord output_coord( + MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow), + MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn) + ); + + // Compute the general matrix product + thread::Gemm< + TensorRefA, + TensorRefB, + TensorRefC, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + > gemm(initial_accum); + + gemm.multiply_add( + problem_size, + tensor_a, + tensor_b, + output_coord); + + gemm.epilogue(problem_size, alpha, beta, tensor_c, tensor_d, output_coord); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. 
+template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType, + typename OutputTile, + typename InnerProductOp, + typename ConvertOp +> +__global__ void BatchedGemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRefCollectionA tensor_collection_a, + TensorRefCollectionB tensor_collection_b, + ScalarType beta, + TensorRefCollectionC tensor_collection_c, + AccumulatorType initial_accum) { + + // Obtain batch ID + int batch_id = blockIdx.z; + + // Dereference based on batch_id + typename TensorRefCollectionA::TensorRef tensor_a = tensor_collection_a.at(batch_id); + typename TensorRefCollectionB::TensorRef tensor_b = tensor_collection_b.at(batch_id); + typename TensorRefCollectionC::TensorRef tensor_c = tensor_collection_c.at(batch_id); + + // Map each thread to a unique tile of the output matrix + MatrixCoord output_coord( + (threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kColumn, + (threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kRow + ); + + // Compute the general matrix product + thread::Gemm< + typename TensorRefCollectionA::TensorRef, + typename TensorRefCollectionB::TensorRef, + typename TensorRefCollectionC::TensorRef, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + > gemm(initial_accum); + + gemm.multiply_add( + problem_size, + tensor_a, + tensor_b, + output_coord); + + gemm.epilogue(problem_size, alpha, beta, tensor_c, output_coord); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h new file mode 100644 index 0000000000000000000000000000000000000000..149e4b2e00e2ac8130cee9dc189a539ba3a70297 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to initialize tensor to uniform random distribution +template +__global__ void TensorInitializeUniform( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + double range = dist.uniform.max - dist.uniform.min; + + double rnd = curand_uniform(&rng_state[threadIdx.x]); + + rnd = dist.uniform.min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + if (dist.int_scale >= 0) { + rnd = double(int(rnd * double(1 << dist.int_scale))); + *tensor = T(rnd / double(1 << dist.int_scale)); + } else { + *tensor = T(rnd); + } + + 
tensor += ldm; + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to initialize tensor to uniform distribution +template +__global__ void TensorInitializeGaussian( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + double rnd = curand_normal(&rng_state[threadIdx.x]); + + rnd = dist.gaussian.mean + dist.gaussian.stddev * rnd; + + if (dist.int_scale >= 0) { + rnd = double(int(rnd * double(1 << dist.int_scale))); + *tensor = T(rnd / double(1 << dist.int_scale)); + } else { + *tensor = T(rnd); + } + } + } +} + +/// Kernel to initialize tensor to an identity matrix +template +__global__ void TensorInitializeLinear( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + *tensor = + dist.linear.offset + dist.linear.delta_row * c_idx + dist.linear.delta_column * s_idx; + } + } +} + +/// Kernel to initialize tensor to 
an identity matrix +template +__global__ void TensorInitializeIdentity( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + *tensor = (c_idx == s_idx ? T(1) : T(0)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h new file mode 100644 index 0000000000000000000000000000000000000000..3223cb2056ba6d88f47f7b117392a56e325d0ce7 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h @@ -0,0 +1,159 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/coord.h" +#include "cutlass/subbyte_reference.h" +#include "cutlass/fast_math.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace kernel { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines several helpers +namespace detail { + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Constructor for general rank + __inline__ __device__ + TensorForEachHelper(Func &func, Coord const &size, Coord &coord, int64_t index) { + + int64_t product = 1; + + CUTLASS_PRAGMA_UNROLL + for (int i = Rank - RankRemaining; i < Rank; ++i) { + product *= size[i]; + } + + coord[Rank - 1 - RankRemaining] = index / product; + int64_t remaining = index % product; + + TensorForEachHelper(func, size, coord, remaining); + } +}; + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Constructor for fastest changing rank + __inline__ __device__ + TensorForEachHelper(Func &func, Coord const &size, Coord &coord, int64_t index) { + + coord[Rank - 1] = index; + + if (coord < size) { + func(coord); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel calls a functor for each element in a tensor's index space +template +__global__ void TensorForEach(Coord size, Params params = Params()) { + + Func func(params); + + int64_t index = threadIdx.x + blockIdx.x * blockDim.x; + int64_t max_index = 1; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + max_index *= size[i]; + } + + CUTLASS_PRAGMA_NO_UNROLL + while (index < max_index) { + Coord coord; + + detail::TensorForEachHelper(func, size, coord, index); + index += blockDim.x * gridDim.x; + } +} + 
+/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel calls a functor for each element along a tensor's diagonal +template +__global__ void TensorDiagonalForEach(Coord size, Params params, int start, int end) { + + Func func(params); + + int64_t index = threadIdx.x + blockIdx.x * blockDim.x + start; + + if (index < end) { + Coord coord; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + coord[i] = index; + } + + func(coord); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void BlockForEach( + Element *ptr, + size_t capacity, + typename Func::Params params) { + + Func func(params); + + size_t index = threadIdx.x + blockIdx.x * blockDim.x; + + for (; index < capacity; index += blockDim.x * gridDim.x) { + ReferenceFactory::get(ptr, index) = func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace device +} // namespace reference +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..2e76fe52b06f9bb1a033c736f94fa01961ce664d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h @@ -0,0 +1,355 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in device-side code. 
+*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kMblock = 4, + int kNblock = 4 +> +__global__ void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + assert(M=N); + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + 
threadIdx.y) * kNblock; + int batch_idx = blockIdx.z; + + tensor_a.add_pointer_offset(batch_idx * batch_stride_A); + tensor_b.add_pointer_offset(batch_idx * batch_stride_B); + tensor_c.add_pointer_offset(batch_idx * batch_stride_C); + tensor_d.add_pointer_offset(batch_idx * batch_stride_D); + + for (; batch_idx < batch_count; batch_idx += gridDim.z) { + + // Compute matrix product using blocks + ComputeType accum[kMblock][kNblock]; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // A x B^T (Symmetric) or A x B^H (Hermitian) + // complex conjugation on operandB (b_t) is function of blas3 computation + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b_t = (blas_mode == BlasMode::kHermitian) ? + conj(tensor_b.at(MatrixCoord(col, k_block))) : + tensor_b.at(MatrixCoord(col, k_block)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_jk = ComputeType(b_t); + + // complex conjugation is a function of operand layouts + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + // complex conjugation is a function of operand layouts + if (transform_b == ComplexTransform::kConjugate) { + b_jk = conj(b_jk); + } + + accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); + + // B x A^T (Symmetric) or B x A^H (Hermitian) + // complex conjugation on operandB (a_t) is function of blas3 computation + ElementB b = tensor_b.at(MatrixCoord(row, k_block)); + ElementA a_t = (blas_mode == BlasMode::kHermitian) ? 
+ conj(tensor_a.at(MatrixCoord(col, k_block))): + tensor_a.at(MatrixCoord(col, k_block)); + + ComputeType b_ik = ComputeType(b); + ComputeType a_jk = ComputeType(a_t); + + // complex conjugation here is a function of operand layouts + if (transform_b == ComplexTransform::kConjugate) { + b_ik = conj(b_ik); + } + // complex conjugation here is a function of operand layouts + if (transform_a == ComplexTransform::kConjugate) { + a_jk = conj(a_jk); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N && + ((fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col)) + ) { + + ScalarType c = tensor_c.at(coord); + // The imaginary parts of the diagonal elements of + // a complex data type are assumed and set to zero + if (blas_mode == BlasMode::kHermitian) { + c = (row == col) ? real(c) : c; + } + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * c); + } + } + } + + tensor_a.add_pointer_offset(batch_stride_A * gridDim.z); + tensor_b.add_pointer_offset(batch_stride_B * gridDim.z); + tensor_c.add_pointer_offset(batch_stride_C * gridDim.z); + tensor_d.add_pointer_offset(batch_stride_D * gridDim.z); + + } // for (batch_idx) +} + +} // namespace kernel + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. 
Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const kMblock = 4; + int const kNblock = 4; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + batch_count % std::numeric_limits::max() + ); + + kernel::Rank2KComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ScalarType, + ComputeType, + ConvertOp, + InnerProductOp, + kMblock, + kNblock + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum, + fill_mode_c, + blas_mode, + batch_count, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + FillMode fill_mode_c, + BlasMode blas_mode) { + + Rank2KComplex( + problem_size, alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, tensor_c, tensor_d, + ScalarType(0), + fill_mode_c, + blas_mode); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h new file mode 100644 index 0000000000000000000000000000000000000000..1999730f6d24e69aef152aa332fae68af57a9c40 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h @@ -0,0 +1,250 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. 
+*/ + +#pragma once +// Standard Library includes +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/relatively_equal.h" + +#include "cutlass/util/distribution.h" + +#include "tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +template +__global__ void BlockCompareEqual( + int *equal, + Element const *ptr_A, + Element const *ptr_B, + size_t capacity) { + + size_t idx = threadIdx.x + blockDim.x * blockIdx.x; + + for (; idx < capacity; idx += gridDim.x * blockDim.x) { + + Element a = cutlass::ReferenceFactory::get(ptr_A, idx); + Element b = cutlass::ReferenceFactory::get(ptr_B, idx); + + if (a != b) { + *equal = 0; + + return; + } + } +} + +template +__global__ void BlockCompareRelativelyEqual( + int *equal, + Element const *ptr_A, + Element const *ptr_B, + size_t capacity, + Element epsilon, + Element nonzero_floor) { + + size_t idx = threadIdx.x + blockDim.x * blockIdx.x; + + for (; idx < capacity; idx += gridDim.x * blockDim.x) { + + Element a = cutlass::ReferenceFactory::get(ptr_A, idx); + Element b = cutlass::ReferenceFactory::get(ptr_B, idx); + + if (!relatively_equal(a, b, epsilon, nonzero_floor)) { + *equal = 0; + return; + } + } +} + +} // namespace kernel + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Performs a bit-level equality check between two blocks +template +bool BlockCompareEqual( + Element const *ptr_A, + Element const *ptr_B, + size_t capacity, + int grid_size = 0, + int block_size = 0, + cudaStream_t stream = nullptr) { + + int equal_flag = 1; + int *device_equal_flag = nullptr; + + if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) { + throw std::runtime_error("Failed to allocate device flag."); + } + + if (cudaMemcpy( + device_equal_flag, + &equal_flag, + sizeof(int), + 
cudaMemcpyHostToDevice) != cudaSuccess) { + + throw std::runtime_error("Failed to copy equality flag to device."); + } + + if (!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::BlockCompareEqual)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::BlockCompareEqual<<< grid, block, 0, stream >>>(device_equal_flag, ptr_A, ptr_B, capacity); + + cudaStreamSynchronize(stream); + + if (cudaMemcpy( + &equal_flag, + device_equal_flag, + sizeof(int), + cudaMemcpyDeviceToHost) != cudaSuccess) { + + cudaFree(device_equal_flag); + + throw std::runtime_error("Failed to copy equality flag from device."); + } + + cudaFree(device_equal_flag); + + return equal_flag; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Performs a bit-level equality check between two blocks +template +bool BlockCompareRelativelyEqual( + Element const *ptr_A, + Element const *ptr_B, + size_t capacity, + Element epsilon, + Element nonzero_floor, + int grid_size = 0, + int block_size = 0, + cudaStream_t stream = nullptr) { + + int equal_flag = 1; + int *device_equal_flag = nullptr; + + if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) { + throw std::runtime_error("Failed to allocate device flag."); + } + + if (cudaMemcpy( + device_equal_flag, + &equal_flag, + sizeof(int), + cudaMemcpyHostToDevice) != cudaSuccess) { + + throw std::runtime_error("Failed to copy equality flag to device."); + } + + if 
(!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::BlockCompareRelativelyEqual)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::BlockCompareRelativelyEqual<<< grid, block, 0, stream >>>( + device_equal_flag, + ptr_A, + ptr_B, + capacity, + epsilon, + nonzero_floor + ); + + cudaStreamSynchronize(stream); + + if (cudaMemcpy( + &equal_flag, + device_equal_flag, + sizeof(int), + cudaMemcpyDeviceToHost) != cudaSuccess) { + + cudaFree(device_equal_flag); + + throw std::runtime_error("Failed to copy equality flag from device."); + } + + cudaFree(device_equal_flag); + + return equal_flag; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // reference +} // cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h new file mode 100644 index 0000000000000000000000000000000000000000..a19b42825f6efb4a39466fe1cfc182ab7d831079 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h @@ -0,0 +1,2075 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines device-side elementwise operations on TensorView. Note, the operations defined + in this header are not specialized for any particular data layout and are therefore not + intended to offer the best possible performance. 
Rather, they are intended to be generic + reference implementations to support the CUTLASS unit tests. +*/ + +#pragma once + +#if !defined(__CUDACC_RTC__) + +// Standard Library includes +#include +#include +#include +#include +#include + +#endif + +// CUDA includes +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/complex.h" +#include "cutlass/tensor_view.h" +#include "cutlass/blas3.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/layout/vector.h" + +#include "cutlass/util/reference/device/tensor_foreach.h" +#include "cutlass/util/distribution.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +CUTLASS_DEVICE +FloatType random_normal_float(curandState_t *state) { + return curand_normal(state); +} + +template <> +CUTLASS_DEVICE +double random_normal_float(curandState_t *state) { + return curand_normal_double(state); +} + +template +CUTLASS_DEVICE +FloatType random_uniform_float(curandState_t *state) { + return curand_uniform(state); +} + +template <> +CUTLASS_DEVICE +double random_uniform_float(curandState_t *state) { + return curand_uniform_double(state); +} + +template +struct RandomGaussianFunc { + + using FloatType = typename std::conditional<(sizeof(Element) > 4), double, float>::type; + using IntType = typename std::conditional<(sizeof(Element) > 4), int64_t, int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType mean; + FloatType stddev; + int int_scale; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + 
// + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + Element mean_ = 0, + Element stddev_ = 1, + int int_scale_ = -1, + int exclude_zero_ = -1 + ): + seed(seed_), + mean(static_cast(mean_)), + stddev(static_cast(stddev_)), + int_scale(int_scale_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomGaussianFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + FloatType rnd = random_normal_float(&rng_state); + rnd = params.mean + params.stddev * rnd; + + Element result; + if (params.int_scale >= 0) { + rnd = FloatType(std::llround(rnd * params.float_scale_up)); + result = Element(rnd * params.float_scale_down); + } + else { + result = Element(rnd); + } + + if (params.exclude_zero >=0 && result == Element(0.0)) { + if (rnd > FloatType(0)) { + rnd += FloatType(1); + } else { + rnd -= FloatType(1); + } + result = Element(rnd); + } + + return result; + } +}; + + +template +struct RandomGaussianFunc> { + + using Element = complex; + using FloatType = typename std::conditional<(sizeof(Real) > 4), double, float>::type; + using IntType = typename std::conditional<(sizeof(Real) > 4), int64_t, int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType mean; + FloatType stddev; + int int_scale; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + 
// + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + Real mean_ = 0, + Real stddev_ = 1, + int int_scale_ = -1, + int exclude_zero_ = -1 + ): + seed(seed_), + mean(static_cast(mean_)), + stddev(static_cast(stddev_)), + int_scale(int_scale_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(1) << int_scale); + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomGaussianFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + FloatType rnd_r = random_normal_float(&rng_state); + FloatType rnd_i = random_normal_float(&rng_state); + rnd_r = params.mean + params.stddev * rnd_r; + rnd_i = params.mean + params.stddev * rnd_i; + + Element result; + if (params.int_scale >= 0) { + rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up)); + rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up)); + + result = { + Real(rnd_r * params.float_scale_down), + Real(rnd_i * params.float_scale_down) + }; + } + else { + result = Element(Real(rnd_r), Real(rnd_i)); + } + + if (params.exclude_zero >= 0 && + result.real() == Real(0.0) && + result.imag() == Real(0.0)) { + + if (rnd_r > FloatType(0)) { + rnd_r += FloatType(1); + } else { + rnd_r -= FloatType(1); + } + result = Element(Real(rnd_r), Real(rnd_i)); + } + + return result; + } +}; + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomGaussianFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar 
type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + using RandomFunc = RandomGaussianFunc; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + typename RandomFunc::Params random; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + typename RandomFunc::Params random_ = typename RandomFunc::Params() + ): + view(view_), random(random_) { + + } + }; + + // + // Data members + // + + Params params; + RandomFunc random; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillRandomGaussianFunc(Params const ¶ms): params(params), random(params.random) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + params.view.at(coord) = random(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomGaussian( + TensorView view, ///< destination tensor + uint64_t seed, ///< seed for RNG + typename RealType::Type mean = Element(0), ///< Gaussian distribution's mean + typename RealType::Type stddev = Element(1), ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. 
+ int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomGaussianFunc; + using Func = detail::TensorFillRandomGaussianFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, typename RandomFunc::Params(seed, mean, stddev, bits, exclude_zero)), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. +template ///< Element type +void BlockFillRandomGaussian( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + typename RealType::Type mean, ///< Gaussian distribution's mean + typename RealType::Type stddev, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. 
+ cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomGaussianFunc; + + typename RandomFunc::Params params(seed, mean, stddev, bits); + + BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random uniform distribution +template ///< Element type +struct RandomUniformFunc { + + using FloatType = typename std::conditional< + (sizeof(Element) > 4), + double, + float>::type; + + using IntType = typename std::conditional< + (sizeof(Element) > 4), + int64_t, + int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType range; + FloatType max; + int int_scale; + double pnan; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + Element max_ = 1, + Element min = 0, + int int_scale_ = -1, + double pnan_ = 0, + int exclude_zero_ = -1 + ): + seed(seed_), + range(static_cast(max_) - static_cast(min)), + max(static_cast(max_)), + int_scale(int_scale_), + pnan(pnan_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero >= 0) { + range = (min == Element(0)) ? range - FloatType(1): range; + max = (max_ == Element(0)) ? 
max - FloatType(1): max; + } + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomUniformFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + // Draw random float in [0.0, 1.0] to determine if element should be NaN. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) { + return Element(NAN); + } + } + + FloatType rnd = random_uniform_float(&rng_state); + rnd = params.max - params.range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (params.int_scale >= 0) { + rnd = FloatType(std::llround(rnd * params.float_scale_up)); + result = Element(rnd * params.float_scale_down); + } + else { + result = Element(rnd); + } + + if (params.exclude_zero >=0 && result == Element(0.0)) { + if (rnd > FloatType(0)) { + rnd = std::min(params.max, rnd + FloatType(1)); + } else { + rnd = std::max((params.max - params.range), rnd - FloatType(1)); + } + result = Element(rnd); + } + + return result; + } +}; + +/// Computes a random Gaussian distribution +template +struct RandomUniformFunc> { + + using Element = complex; + + using FloatType = typename std::conditional< + (sizeof(Real) > 4), + double, + float>::type; + + using IntType = typename std::conditional< + (sizeof(Real) > 4), + int64_t, + int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType range; + FloatType min; + int int_scale; + double pnan; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, 
excludes zeros + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + FloatType max = 1, + FloatType min_ = 0, + int int_scale_ = -1, + double pnan_ = 0, + int exclude_zero_ = -1 + ): + seed(seed_), + range(static_cast(max - min_)), + min(static_cast(min_)), + int_scale(int_scale_), + pnan(pnan_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(1) << int_scale); + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero >= 0) { + min = (min == FloatType(0)) ? min + FloatType(1): min; + range = (max == FloatType(0)) ? range - FloatType(1): range; + } + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomUniformFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + // Draw random float in [0.0, 1.0] to determine if element should be NaN. 
+ if constexpr (std::numeric_limits::has_quiet_NaN) { + if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) { + return Element(Real(NAN), Real(NAN)); + } + } + + FloatType rnd_r = random_uniform_float(&rng_state); + FloatType rnd_i = random_uniform_float(&rng_state); + + rnd_r = params.min + params.range * rnd_r; + rnd_i = params.min + params.range * rnd_i; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (params.int_scale >= 0) { + rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up)); + rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up)); + + result = { + Real(rnd_r * params.float_scale_down), + Real(rnd_i * params.float_scale_down) + }; + } + else { + result = Element(Real(rnd_r), Real(rnd_i)); + } + + if (params.exclude_zero >= 0 && + result.real() == Real(0.0) && + result.imag() == Real(0.0)) { + + if (rnd_r > FloatType(0)) { + rnd_r = std::min(params.min + params.range, rnd_r + FloatType(1)); + } else { + rnd_r = std::max((params.min), rnd_r - FloatType(1)); + } + result = Element(Real(rnd_r), Real(rnd_i)); + } + + return result; + } +}; + +/// Computes a random uniform distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomUniformFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + using RandomFunc = RandomUniformFunc; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + typename RandomFunc::Params random; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. 
+ Params( + TensorView view_ = TensorView(), + typename RandomFunc::Params random_ = RandomFunc::Params() + ): + view(view_), random(random_) { + + } + }; + + // + // Data members + // + + Params params; + RandomFunc random; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillRandomUniformFunc(Params const ¶ms): params(params), random(params.random) { + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + params.view.at(coord) = random(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorView view, ///< destination tensor + uint64_t seed, ///< seed for RNG + typename RealType::Type max = Element(1), ///< upper bound of distribution + typename RealType::Type min = Element(0), ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. + int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomUniformFunc; + using Func = detail::TensorFillRandomUniformFunc; + using Params = typename Func::Params; + + typename RandomFunc::Params random(seed, max, min, bits, pnan, exclude_zero); + + TensorForEach( + view.extent(), + Params(view, random), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. 
+template +void BlockFillRandomUniform( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + typename RealType::Type max, ///< upper bound of distribution + typename RealType::Type min, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomUniformFunc; + + typename RandomFunc::Params params(seed, max, min, bits, pnan); + + BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random sparse meta +template ///< Element type +struct RandomSparseMetaFunc { + + using FloatType = float; + + using IntType = int32_t; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType range; + int MetaSizeInBits; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. 
+ Params( + uint64_t seed_ = 0, + int MetaSizeInBits_ = 2 + ): + seed(seed_), + MetaSizeInBits(MetaSizeInBits_) { + if (MetaSizeInBits_ == 2) { + range = 6; + } + else if (MetaSizeInBits_ == 4) { + range = 2; + } + else { + throw std::invalid_argument("Invalid MetaSizeInBits"); + } + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomSparseMetaFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}; + Element TwoToOneMeta[2] = {0x4, 0xe}; + + Element *MetaArray = + (params.MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta; + + Element result = 0x0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { + FloatType rnd = random_uniform_float(&rng_state); + rnd = params.range * rnd; + Element meta = MetaArray[(int)rnd]; + + result = (Element)(result | ((Element)(meta << (i * 4)))); + } + + return result; + } +}; + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomSparseMetaFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + using RandomFunc = RandomSparseMetaFunc; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + typename RandomFunc::Params random; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. 
+ Params( + TensorView view_ = TensorView(), + typename RandomFunc::Params random_ = RandomFunc::Params() + ): + view(view_), random(random_) { + + } + }; + + // + // Data members + // + + Params params; + RandomFunc random; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillRandomSparseMetaFunc(Params const ¶ms): params(params), random(params.random) { + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + params.view.at(coord) = random(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomSparseMeta( + TensorView view, ///< destination tensor + uint64_t seed, ///< seed for RNG + int MetaSizeInBits = 2, ///< meta data size + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomSparseMetaFunc; + using Func = detail::TensorFillRandomUniformFunc; + using Params = typename Func::Params; + + typename RandomFunc::Params random(seed, MetaSizeInBits); + + TensorForEach( + view.extent(), + Params(view, random), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. 
+template +void BlockFillRandomSparseMeta( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + int MetaSizeInBits = 2, ///< meta data size + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomSparseMetaFunc; + + typename RandomFunc::Params params(seed, MetaSizeInBits); + + BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Functor to fill a tensor with zeros off the diagonal and a uniform value on the diagonal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillDiagonalFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element diag; + Element other; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + Params( + TensorView view_ = TensorView(), + Element diag_ = Element(1), + Element other_ = Element(0) + ): + view(view_), diag(diag_), other(other_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillDiagonalFunc(Params const ¶ms): params(params) { + + } + + /// Updates the tensor + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + params.view.at(coord) = (is_diag ? 
params.diag : params.other); + } +}; + +// Overwrites the elements of a tensor with a uniform value depending on fill mode +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillPartialFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element element; + FillMode fill_mode; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params(): fill_mode(FillMode::kNone) { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_, + Element element_, + FillMode fill_mode_ + ): + view(view_), element(element_), fill_mode(fill_mode_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + CUTLASS_DEVICE + TensorFillPartialFunc(Params const ¶ms): params(params) { + + } + + /// Overwrites the element if it is within the covered region. 
+ CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool predicate = true; + + switch (params.fill_mode) { + case FillMode::kFull: + predicate = true; + break; + + case FillMode::kLower: + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i - 1] < coord[i]) { + predicate = false; + break; + } + } + break; + + case FillMode::kUpper: + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i - 1] > coord[i]) { + predicate = false; + break; + } + } + break; + + case FillMode::kDiagonal: + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i - 1] != coord[i]) { + predicate = false; + break; + } + } + break; + + case FillMode::kNone: // fall-through + + default: + predicate = false; + break; + } + + if (predicate) { + params.view.at(coord) = params.element; + } + } +}; + + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorClearPartialFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// + static_assert((Layout::kRank == 2), "TensorClearPartial is only supported for matrices"); + + /// Parameters structure + struct Params { + TensorView view{}; + Element element{}; + FillMode fill_mode{FillMode::kNone}; + int alignment{0}; + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + CUTLASS_DEVICE + TensorClearPartialFunc(Params const ¶ms): params(params) { + + } + + /// Overwrites the element if it is within the covered region. 
+ CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool predicate = true; + + switch (params.fill_mode) { + + case FillMode::kLower: + // Write only where coord[0] < coord[1] and the difference is below params.alignment. + if ((coord[0] >= coord[1]) || + ((coord[1] - coord[0]) >= params.alignment)) { + predicate = false; + break; + } + break; + + case FillMode::kUpper: + // Write only where coord[0] > coord[1] and the difference is below params.alignment. + if ((coord[0] <= coord[1]) || + ((coord[0] - coord[1]) >= params.alignment)) { + predicate = false; + break; + } + break; + + case FillMode::kNone: // fall-through + + default: + predicate = false; + break; + } + + if (predicate) { + params.view.at(coord) = params.element; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor everywhere with a unique value for its diagonal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillDiagonal( + TensorView view, ///< destination tensor + Element diag = Element(1), ///< value to write in the diagonal + Element other = Element(0), ///< value to write off the diagonal + cudaStream_t stream = nullptr) { + + typedef detail::TensorFillDiagonalFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, diag, other), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/// Fills a tensor partially depending on fill mode. Elements not covered by the fillmode are +/// not written. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillPartial( + TensorView view, ///< destination tensor + Element element, ///< value written to every covered element + FillMode fill_mode, ///< selects which elements are covered + cudaStream_t stream = nullptr) { + + typedef detail::TensorFillPartialFunc Func; + typedef typename Func::Params Params; + + // grid_size and block_size are passed explicitly so that `stream` binds to the stream + // parameter of TensorForEach instead of its integer grid_size parameter. + TensorForEach( + view.extent(), + Params(view, element, fill_mode), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/// Clears a tensor partially depending on fill mode and alignment. 
Elements on the wrong side +/// of fillmode (up to the alignment) are overwritten with the user-supplied element (typically zeros). +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorClearPartial( + TensorView view, ///< destination tensor + Element element, ///< value written to the cleared elements + FillMode fill_mode, ///< only kLower and kUpper cause writes; other modes write nothing + int alignment, ///< elements whose two coordinates differ by less than this are candidates + cudaStream_t stream = nullptr) { + + typedef detail::TensorClearPartialFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params{view, element, fill_mode, alignment}, + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a uniform value +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFill( + TensorView view, ///< destination tensor + Element val = Element(0), ///< value to uniformly fill it with + cudaStream_t stream = nullptr) { + + // A uniform fill is a diagonal fill whose on- and off-diagonal values are the same. + TensorFillDiagonal(view, val, val, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor's diagonal with 1 and 0 everywhere else. 
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillIdentity( + TensorView view, ///< destination tensor + cudaStream_t stream = nullptr) { + + // Delegates to TensorFillDiagonal with 1 on the diagonal and 0 off the diagonal. + TensorFillDiagonal(view, Element(1), Element(0), stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Functor that overwrites only the diagonal elements of a tensor with a uniform value +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorUpdateDiagonalFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element diag; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Constructs the parameters from a destination view and the value to write on the diagonal. 
+ Params( + TensorView view_ = TensorView(), + Element diag_ = Element(1) + ): + view(view_), diag(diag_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side constructor; stores a copy of the parameters + CUTLASS_DEVICE + TensorUpdateDiagonalFunc(Params const ¶ms): params(params) { + + } + + /// Writes params.diag at coord when all components of coord are equal; otherwise writes nothing + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool is_diag = true; + + // Each component is compared with its predecessor; one mismatch means off-diagonal. + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + if (is_diag) { + params.view.at(coord) = params.diag; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateDiagonal( + TensorView view, ///< destination tensor + Element diag = Element(1), ///< value to write on the diagonal + cudaStream_t stream = nullptr) { + + typedef detail::TensorUpdateDiagonalFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, diag), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Functor that overwrites only the off-diagonal elements of a tensor with a uniform value +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorUpdateOffDiagonalFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct 
Params { + + // + // Data members + // + + TensorView view; + Element other; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Constructs the parameters from a destination view and the value to write off the diagonal. + Params( + TensorView view_ = TensorView(), + Element other_ = Element(0) + ): + view(view_), other(other_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side constructor; stores a copy of the parameters + CUTLASS_DEVICE + TensorUpdateOffDiagonalFunc(Params const ¶ms): params(params) { + + } + + /// Writes params.other at coord unless all components of coord are equal + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool is_diag = true; + + // Each component is compared with its predecessor; one mismatch means off-diagonal. + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + if (!is_diag) { + params.view.at(coord) = params.other; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to all off-diagonal elements of the tensor without modifying diagonal elements. 
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateOffDiagonal( + TensorView view, ///< destination tensor + Element other = Element(1), ///< value to write off the diagonal + cudaStream_t stream = nullptr) { + + typedef detail::TensorUpdateOffDiagonalFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, other), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Functor that writes s plus the dot product of v and the coordinate into each element +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillLinearFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Array v; + Element s; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Constructs the parameters from a destination view, per-rank coefficients, and an offset. 
+ Params( + TensorView view_, ///< destination tensor + Array const & v_, ///< coefficient applied to each coordinate component + Element s_ = Element(0) ///< constant term the sum starts from + ): + view(view_), v(v_), s(s_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side constructor; stores a copy of the parameters + CUTLASS_DEVICE + TensorFillLinearFunc(Params const ¶ms): params(params) { + + } + + /// Accumulates params.s + sum over i of params.v[i] * coord[i] and stores it at coord + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + Element sum = params.s; + + // Elements of at most 32 bits are accumulated through casts to another type and converted + // back; the remaining non-complex case accumulates directly in Element. + // NOTE(review): the complex branch contains no statement for elements wider than 32 bits, + // so in that case `sum` appears to stay at params.s - confirm this is intended. + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank; ++i) { + if constexpr (is_complex::value) { + if constexpr (sizeof_bits::value <= 32) { + sum = Element(static_cast>(sum) + + static_cast>(params.v[i]) * static_cast>(coord[i])); + } + } + else if constexpr (sizeof_bits::value <= 32) { + if constexpr (std::numeric_limits::is_integer) { + sum = Element(static_cast(sum) + + static_cast(params.v[i]) * static_cast(coord[i])); + } + else { + sum = Element(static_cast(sum) + + static_cast(params.v[i]) * static_cast(coord[i])); + } + } + else { + sum += params.v[i] * coord[i]; + } + } + + params.view.at(coord) = sum; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills tensor with a linear combination of its coordinate and another vector +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillLinear( + TensorView view, ///< destination tensor + Array const & v, ///< coefficient applied to each coordinate component + Element s = Element(0), ///< constant term added to the combination + cudaStream_t stream = nullptr) { + + using Func = detail::TensorFillLinearFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, v, s), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + 
+/// Fills a tensor with random values from a distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandom( + TensorView view, ///< destination tensor + uint64_t seed, + Distribution dist, + cudaStream_t stream = nullptr, + int exclude_zero = -1 ///< If non-negative, excludes 0. + /// Note that setting this flag will result in more 1's, + /// as we use a simple mechanism to replace 0's by adding/subtracting 1's. + ) { + + using Real = typename RealType::Type; + + if (dist.kind == Distribution::Gaussian) { + TensorFillRandomGaussian( + view, + seed, + static_cast(dist.gaussian.mean), + static_cast(dist.gaussian.stddev), + dist.int_scale, + exclude_zero, + stream); + } else if (dist.kind == Distribution::Uniform) { + TensorFillRandomUniform( + view, + seed, + static_cast(dist.uniform.max), + static_cast(dist.uniform.min), + dist.int_scale, + dist.uniform.pnan, + exclude_zero, + stream); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequential( + Element *ptr, + int64_t capacity, + Element v = Element(1), + Element s = Element(0)) { + + using Layout = layout::PackedVectorLayout; + Layout::TensorCoord size(static_cast(capacity)); // -Wconversion + Layout layout = Layout::packed(size); + TensorView view(ptr, layout, size); + + Array c{}; + c[0] = v; + + TensorFillLinear(view, c, s); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillRandom( + Element *ptr, + size_t capacity, + uint64_t 
seed, + Distribution dist, + cudaStream_t stream = nullptr) { + + using Real = typename RealType::Type; + + if (dist.kind == Distribution::Gaussian) { + BlockFillRandomGaussian( + ptr, + capacity, + seed, + static_cast(dist.gaussian.mean), + static_cast(dist.gaussian.stddev), + dist.int_scale, + stream); + } + else if (dist.kind == Distribution::Uniform) { + BlockFillRandomUniform( + ptr, + capacity, + seed, + static_cast(dist.uniform.max), + static_cast(dist.uniform.min), + dist.int_scale, + dist.uniform.pnan, + stream); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorCopyDiagonalInFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element const *ptr; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. 
+ Params( + TensorView view_, ///< destination tensor + Element const *ptr_ + ): + view(view_), ptr(ptr_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorCopyDiagonalInFunc(Params const ¶ms): params(params) { + + } + + /// Only update the diagonal element + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + bool is_diagonal = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[0]) { + is_diagonal = false; + } + } + if (is_diagonal) { + params.view.at(coord) = params.ptr[coord[0]]; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies a diagonal in from host memory without modifying off-diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalIn( + TensorView view, ///< destination tensor + Element const *ptr, ///< dense buffer of elements + cudaStream_t stream = nullptr) { + + using Func = detail::TensorCopyDiagonalInFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, ptr), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorCopyDiagonalOutFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + 
struct Params { + + // + // Data members + // + + TensorView view; + Element *ptr; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_, ///< destination tensor + Element *ptr_ + ): + view(view_), ptr(ptr_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorCopyDiagonalOutFunc(Params const ¶ms): params(params) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + bool is_diagonal = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[0]) { + is_diagonal = false; + } + } + if (is_diagonal) { + params.ptr[coord[0]] = params.view.at(coord); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies the diagonal of a tensor into a dense buffer in host memory. 
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalOut( + Element *ptr, ///< dense buffer of elements + TensorView view, ///< source tensor + cudaStream_t stream = nullptr) { + + using Func = detail::TensorCopyDiagonalOutFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, ptr), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h new file mode 100644 index 0000000000000000000000000000000000000000..ba2dfd85c47b8c9450c348de32dccb7f1be9c3c1 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h @@ -0,0 +1,142 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/util/reference/device/kernel/tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Launches a kernel calling a functor for each element in a tensor's index space. +template +struct TensorForEach { + + /// Constructor performs the operation. 
+ TensorForEach( + Coord size, Params params = Params(), + int grid_size = 0, int block_size = 0, + cudaStream_t stream = nullptr) { + + if (!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::TensorForEach)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::TensorForEach<<< grid, block, 0, stream >>>(size, params); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Launches a kernel calling a functor for each element along a tensor's diagonal +template +struct TensorDiagonalForEach { + + /// Constructor performs the operation + TensorDiagonalForEach( + Coord size, Params params = Params(), + int start = 0, int end = -1, + int block_size = 128, cudaStream_t stream = nullptr) { + + if (end < 0) { + end = size.min(); + } + + dim3 block(block_size, 1, 1); + dim3 grid((end - start + block_size - 1) / block_size, 1, 1); + + kernel::TensorDiagonalForEach<<< grid, block, 0, stream >>>( + size, params, start, end); + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockForEach { + + /// Constructor performs the operation. 
+ BlockForEach( + Element *ptr, + size_t capacity, + typename Func::Params params = typename Func::Params(), + int grid_size = 0, + int block_size = 0, + cudaStream_t stream = nullptr) { + + if (!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::BlockForEach)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::BlockForEach<<< grid, block, 0, stream >>>(ptr, capacity, params); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..3e6d7b300f34fec6aec96e72f78427cf677936b4 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h @@ -0,0 +1,514 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/detail/linear_to_coordinate.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp, + int kBlockSize = 128 +> +__global__ void TensorTransformReducePartial( + TensorView view, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace) { /// Device-side workspace for accumulating partial results. 
The reduced element is stored in workspace[0] + + int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + int64_t size = view.size(); + + __shared__ ComputeType scratchpad[kBlockSize]; + + for (; idx < size; idx += blockDim.x * gridDim.x) { + + // Map linear thread ID onto tensor coordinate + typename Layout::TensorCoord coord; + + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view.extent()); + + if (view.contains(coord)) { + + // Fetch element + Element x = view.at(coord); + + // Transform + identity = reduce(identity, transform(x)); + } + } + + scratchpad[threadIdx.x] = identity; + + __syncthreads(); + + // One thread performs the final reduction and stores out. This could be enhanced via + // a tree reduction and pipelining. + if (threadIdx.x == 0) { + + for (int i = 1; i < kBlockSize; ++i) { + identity = reduce(identity, scratchpad[i]); + } + + workspace[blockIdx.x] = identity; + } +} + +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp, + int kBlockSize = 128 +> +__global__ void TensorTransformReducePartial( + TensorView view_A, /// View of the tensor to reduce over + TensorView view_B, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace) { /// Device-side workspace for accumulating partial results. 
The reduced element is stored in workspace[0] + + int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + auto size = static_cast(view_A.size()); + + __shared__ ComputeType scratchpad[kBlockSize]; + + for (; idx < size; idx += blockDim.x * gridDim.x) { + + // Map linear thread ID onto tensor coordinate + typename Layout::TensorCoord coord; + + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view_A.extent()); + + if (view_A.contains(coord)) { + + // Fetch element + Element a = view_A.at(coord); + Element b = view_B.at(coord); + + // Transform + identity = reduce(identity, transform(a, b)); + } + } + + scratchpad[threadIdx.x] = identity; + + __syncthreads(); + + // One thread performs the final reduction and stores out. This could be enhanced via + // a tree reduction and pipelining. + if (threadIdx.x == 0) { + + for (int i = 1; i < kBlockSize; ++i) { + identity = reduce(identity, scratchpad[i]); + } + + workspace[blockIdx.x] = identity; + } +} + + +template < + typename ComputeType, + typename ReduceOp, + int kBlockSize = 32 +> +__global__ void TensorTransformReduceFinalize( + ComputeType *workspace, + ComputeType identity, + int workspace_size, + ReduceOp reduce) { + + __shared__ ComputeType scratchpad[kBlockSize]; + + for (int idx = threadIdx.x; idx < workspace_size; idx += kBlockSize) { + identity = reduce(identity, workspace[idx]); + } + + scratchpad[threadIdx.x] = identity; + + __syncthreads(); + + if (threadIdx.x == 0) { + + for (int i = 1; i < kBlockSize; ++i) { + identity = reduce(identity, scratchpad[i]); + } + + workspace[0] = identity; + } +} + +} // namespace kernel + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Transform-reduce operation over the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view, /// View of the tensor to reduce over + 
ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + int workspace_size, /// Number of elements in workspace + cudaStream_t stream = nullptr, /// CUDA stream to launch into + bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned. +) { + + int const kBlockSize = 128; + + dim3 block(kBlockSize, 1); + dim3 grid(workspace_size, 1); // one block per workspace element; each block writes workspace[blockIdx.x] + + kernel::TensorTransformReducePartial< + Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize + ><<< grid, block, 0, stream >>>( + view, identity, reduce, transform, workspace + ); + + int const kFinalizeBlockSize = 32; + + kernel::TensorTransformReduceFinalize< + ComputeType, ReduceOp, kFinalizeBlockSize + ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>( + workspace, identity, workspace_size, reduce + ); + // Wait for both kernels and report a failure instead of silently discarding the status. + if (cudaStreamSynchronize(stream) != cudaSuccess) { throw std::runtime_error("cudaStreamSynchronize() failed"); } + + if (copy_out) { + cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost); + if (result != cudaSuccess) { + throw std::runtime_error("cudaMemcpy() failed"); + } + } + + return identity; +} + +/// Transform-reduce operation over the elements of two tensors, zipped together +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view_A, /// View of the first tensor to reduce over + TensorView view_B, /// View of the second tensor, read at the same coordinates as view_A + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element:
f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms a pair of tensor elements to ComputeType: g(Element, Element) => ComputeType + ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + int workspace_size, /// Number of elements in workspace + cudaStream_t stream = nullptr, /// CUDA stream to launch into + bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned. +) { + + if (view_A.extent() != view_B.extent()) { + throw std::runtime_error("Extents must be equal."); + } + + int const kBlockSize = 128; + + dim3 block(kBlockSize, 1); + dim3 grid(workspace_size, 1); // one block per workspace element; each block writes workspace[blockIdx.x] + + kernel::TensorTransformReducePartial< + Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize + ><<< grid, block, 0, stream >>>( + view_A, view_B, identity, reduce, transform, workspace + ); + + int const kFinalizeBlockSize = 32; + + kernel::TensorTransformReduceFinalize< + ComputeType, ReduceOp, kFinalizeBlockSize + ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>( + workspace, identity, workspace_size, reduce + ); + // Wait for both kernels and report a failure instead of silently discarding the status. + if (cudaStreamSynchronize(stream) != cudaSuccess) { throw std::runtime_error("cudaStreamSynchronize() failed"); } + + if (copy_out) { + cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost); + if (result != cudaSuccess) { + throw std::runtime_error("cudaMemcpy() failed"); + } + } + + return identity; +} + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view, + ComputeType identity, + ReduceOp reduce, + TransformOp transform, + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + // Optionally query for the SM count to size the workspace.
+ if (workspace_size <= 0) { // zero (the default) or a negative count selects one workspace element per SM + + int device_idx = 0; + cudaDeviceProp prop; + + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() failed"); + } + + result = cudaGetDeviceProperties(&prop, device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + workspace_size = int(prop.multiProcessorCount); + } + + DeviceAllocation workspace(workspace_size); + + ComputeType output = TensorTransformReduce( + view, + identity, + reduce, + transform, + workspace.get(), + workspace_size, + stream, + true); // copy_out: always bring the reduced value back to the host + + return output; +} + + +/// Transform-reduce operation over the elements of two tensors, zipped together. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view_A, + TensorView view_B, + ComputeType identity, + ReduceOp reduce, + TransformOp transform, + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + // Optionally query for the SM count to size the workspace.
+ if (workspace_size <= 0) { // zero (the default) or a negative count selects one workspace element per SM + + int device_idx = 0; + cudaDeviceProp prop; + + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() failed"); + } + + result = cudaGetDeviceProperties(&prop, device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + workspace_size = int(prop.multiProcessorCount); + } + + DeviceAllocation workspace(workspace_size); + + ComputeType output = TensorTransformReduce( + view_A, + view_B, + identity, + reduce, + transform, + workspace.get(), + workspace_size, + stream, + true); // copy_out: always bring the reduced value back to the host + + return output; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to compute the sum of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSum( + TensorView view, /// View of the tensor to reduce over + ComputeType identity = ComputeType(), /// Starting value of the reduction + cudaStream_t stream = nullptr, /// CUDA stream, forwarded to TensorTransformReduce + int workspace_size = 0 /// Workspace element count, forwarded to TensorTransformReduce +) { + + plus reduce; + NumericConverter transform; + + return TensorTransformReduce( + view, identity, reduce, transform, stream, workspace_size); +} + +/// Helper to compute the sum of the squares of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSumSq( + TensorView view, /// View of the tensor to reduce over + ComputeType identity = ComputeType(), /// Starting value of the reduction + cudaStream_t stream = nullptr, /// CUDA stream, forwarded to TensorTransformReduce + int workspace_size = 0 /// Workspace element count, forwarded to TensorTransformReduce +) { + + plus reduce; + magnitude_squared transform; + + return TensorTransformReduce( + view, identity, reduce, transform, stream, workspace_size); +} + +/// Helper to compute the norm of the elements of a tensor.
+template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNorm( + TensorView view, /// View of the tensor to reduce over + ComputeType identity = ComputeType(), /// Starting value, forwarded to TensorSumSq + cudaStream_t stream = nullptr, /// CUDA stream, forwarded to TensorSumSq + int workspace_size = 0 /// Workspace element count, forwarded to TensorSumSq +) { + // Square root of the value returned by TensorSumSq. + return std::sqrt(TensorSumSq(view, identity, stream, workspace_size)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to compute the sum of the squares of the differences of two tensors +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorSumSqDiff( + TensorView view_A, /// View of the first tensor + TensorView view_B, /// View of the second tensor + ComputeType identity = ComputeType(), /// Starting value of the reduction + cudaStream_t stream = nullptr, /// CUDA stream, forwarded to TensorTransformReduce + int workspace_size = 0 /// Workspace element count, forwarded to TensorTransformReduce +) { + + plus reduce; + magnitude_squared_difference transform; + + return TensorTransformReduce( + view_A, view_B, identity, reduce, transform, stream, workspace_size); +} + + +/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNormDiff( + TensorView view_A, /// View of the first tensor + TensorView view_B, /// View of the second tensor + ComputeType identity = ComputeType(), /// Starting value, forwarded to TensorSumSqDiff + cudaStream_t stream = nullptr, /// CUDA stream, forwarded to TensorSumSqDiff + int workspace_size = 0 /// Workspace element count, forwarded to TensorSumSqDiff +) { + // Square root of the value returned by TensorSumSqDiff. + return std::sqrt(TensorSumSqDiff(view_A, view_B, identity, stream, workspace_size)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h new file mode 100644 index
0000000000000000000000000000000000000000..0e3d99ddf845810249f909fbdee4505a0a732c4f --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h @@ -0,0 +1,141 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines device-side elementwise operations on TensorView. Note, the operations defined + in this header are not specialized for any particular data layout and are therefore not + intended to offer the best possible performance. Rather, they are intended to be generic + reference implementations to support the CUTLASS unit tests. +*/ + +#pragma once + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/tensor_view.h" + +#include "cutlass/util/reference/device/tensor_foreach.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Functor that raises one tensor element to at least Params::threshold (see operator()). +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorReLuFunc { + + /// View type + using TensorView = TensorView; + + /// Coordinate in tensor's index space + using TensorCoord = typename TensorView::TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; // tensor that is read and overwritten in place + Element threshold; // lower bound applied to every visited element + + + // + // Methods + // + + Params( + TensorView
view_ = TensorView(), + Element threshold_ = Element(0) + ): + view(view_), threshold(threshold_) { + + } + }; + + // + // Data members + // + + Params params; + + // + // Methods + // + + CUTLASS_DEVICE + TensorReLuFunc(Params const &params): params(params) { + + } + + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + // Clamp from below: an element smaller than the threshold is overwritten with the threshold itself. + Element const & value = params.view.at(coord); + params.view.at(coord) = (value < params.threshold) ? params.threshold : value; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Apply ReLu on a tensor in place (elements below `threshold` are replaced by `threshold`) +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorReLu( + TensorView view, ///< destination tensor + Element threshold = Element(0)) { ///< ReLu threshold + + using Func = detail::TensorReLuFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, threshold) + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..dd11f96bd92f6995590e61665e41a3e830bceacd --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h @@ -0,0 +1,186 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES.
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for thread-level GEMM in device-side code.
+*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace thread { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level blocked general matrix product. +// +// Note, this is a reference implementation. Performance is not expected to approach peak. +// +template < + typename TensorRefA, + typename TensorRefB, + typename TensorRefC, + typename ScalarType, + typename AccumulatorType, + typename OutputTile, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +struct Gemm { + + using ElementA = typename TensorRefA::Element; + using ElementB = typename TensorRefB::Element; + using ElementC = typename TensorRefC::Element; + + // + // Data members + // + + /// Tile for A operand + ElementA A_tile[OutputTile::kColumn]; + + /// Tile for B operand + ElementB B_tile[OutputTile::kRow]; + + /// Tile for Accumulator, indexed accum[j][i] with j < kRow and i < kColumn (the order multiply_add and epilogue use) + AccumulatorType accum[OutputTile::kRow][OutputTile::kColumn]; + + // + // Methods + // + + /// Constructor: zeroes both operand tiles and sets every accumulator to initial_accum + CUTLASS_HOST_DEVICE + Gemm(AccumulatorType initial_accum = AccumulatorType(0)) { + + // Clear fetch registers + for (int i = 0; i < OutputTile::kColumn; ++i) { + A_tile[i] = ElementA(0); + } + + for (int j = 0; j < OutputTile::kRow; ++j) { + B_tile[j] = ElementB(0); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < OutputTile::kRow; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < OutputTile::kColumn; ++i) { + accum[j][i] = initial_accum; + } + } + } + + /// Computes a matrix product: accumulates A(row + i, k) * B(k, column + j) into accum[j][i] for every k; returns *this + CUTLASS_HOST_DEVICE + Gemm & multiply_add( + gemm::GemmCoord problem_size, + TensorRefA tensor_a, + TensorRefB tensor_b, + MatrixCoord output_coord = MatrixCoord()) { + + InnerProductOp inner_product_op; + + // Loop over the GEMM K dimension + CUTLASS_PRAGMA_NO_UNROLL + for (int k = 0; k < problem_size.k(); ++k) {
+ + // Fetch a slice of the A matrix + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < OutputTile::kColumn; ++i) { + if (output_coord.row() + i < problem_size.m()) { // rows past M keep the value A_tile[i] already holds (zero after construction) + A_tile[i] = tensor_a.at(make_Coord(output_coord.row() + i, k)); + } + } + + // Fetch a slice of the B matrix + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < OutputTile::kRow; ++j) { + if (output_coord.column() + j < problem_size.n()) { // columns past N keep the value B_tile[j] already holds (zero after construction) + B_tile[j] = tensor_b.at(make_Coord(k, output_coord.column() + j)); + } + } + + // Compute an accumulated matrix product + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < OutputTile::kRow; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < OutputTile::kColumn; ++i) { + accum[j][i] = inner_product_op(A_tile[i], B_tile[j], accum[j][i]); + } + } + } + + return *this; + } + + /// Performs linear scaling of matrix product and updates output tensor: d = convert(alpha * accum + beta * c) for in-range coordinates + CUTLASS_HOST_DEVICE + Gemm & epilogue( + gemm::GemmCoord problem_size, + ScalarType alpha, + ScalarType beta, + TensorRefC tensor_c, + TensorRefC tensor_d, + MatrixCoord output_coord = MatrixCoord()) { + + ConvertOp convert_op; + + // Update the output tensor + for (int j = 0; j < OutputTile::kRow; ++j) { + for (int i = 0; i < OutputTile::kColumn; ++i) { + MatrixCoord coord = output_coord + MatrixCoord(i, j); + if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[j][i]) + + beta * ScalarType(tensor_c.at(coord)) + ); + } + } + } + + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp new file mode 100644 index
0000000000000000000000000000000000000000..57443325629ea4e5d855fe18f94c73b10a71a73a --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp @@ -0,0 +1,782 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for CONV in host-side code. +*/ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" + +#include "cute/tensor.hpp" + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::reference::host { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +bool +is_activation_in_bounds( + cute::Tensor const& activation, + int32_t n_, int32_t d_, int32_t h_, int32_t w_, int32_t c_, int32_t g_) { // 3 spatial dims: checks c, w, h, d, n, g against modes 0..5 of activation + return ((g_ >= 0 && g_ < size<5>(activation)) && + (n_ >= 0 && n_ < size<4>(activation)) && + (d_ >= 0 && d_ < size<3>(activation)) && + (h_ >= 0 && h_ < size<2>(activation)) && + (w_ >= 0 && w_ < size<1>(activation)) && + (c_ >= 0 && c_ < size<0>(activation))); +} + +template +bool +is_activation_in_bounds( + cute::Tensor const& activation, + int32_t n_, int32_t h_, int32_t w_, int32_t c_, int32_t g_) { // 2 spatial dims: checks c, w, h, n, g against modes 0..4 of activation + return ((g_ >= 0 && g_ < size<4>(activation)) && + (n_ >= 0 && n_ < size<3>(activation)) && + (h_ >= 0 && h_ < size<2>(activation)) && + (w_ >= 0 && w_ < size<1>(activation)) && + (c_ >= 0 && c_ < size<0>(activation))); +} + +template +bool +is_activation_in_bounds( + cute::Tensor const& activation, + int32_t n_, int32_t w_, int32_t c_, int32_t g_) { // 1 spatial dim: checks c, w, n, g against modes 0..3 of activation + return ((g_ >= 0 && g_ < size<3>(activation)) && + (n_ >= 0 && n_ < size<2>(activation)) && + (w_ >= 0 && w_ < size<1>(activation)) && + (c_ >= 0 && c_ < size<0>(activation))); +} + +} // namespace detail + +template< + class ElementAcc_, + class ElementScalar_, + class ElementCompute_, + class ElementC_, + class ElementOut_, + bool ResidualAdd_, +
class TensorAlpha_, + class TensorBeta_, + class TensorBias_, + class ActivationFunctor_ = cutlass::epilogue::thread::Identity +> +struct ConvEpilogueFusionParams { + using ElementAcc = ElementAcc_; + using ElementScalar = ElementScalar_; + using ElementCompute = ElementCompute_; + using ElementC = ElementC_; + using ElementOut = ElementOut_; + using TensorAlpha = TensorAlpha_; + using TensorBeta = TensorBeta_; + using TensorBias = TensorBias_; + using ActivationFunctor = ActivationFunctor_; + static constexpr bool ResidualAdd = ResidualAdd_; // When true, the scaled source (C) term is added after the activation instead of before it + + ElementScalar alpha = ElementScalar(1); // scalar scale for the accumulator; used when tensor_alpha has no data + ElementScalar beta = ElementScalar(0); // scalar scale for the source (C) term; used when tensor_beta has no data + + TensorAlpha tensor_alpha{}; // optional alpha values, indexed by output channel k in fprop + TensorBeta tensor_beta{}; // optional beta values, indexed by output channel k in fprop + TensorBias tensor_bias{}; // optional bias values, indexed by output channel k in fprop; skipped when it has no data +}; + +template< + cutlass::conv::Operator ConvOp, + int NumSpatialDims, + class TensorA, + class TensorB, + class TensorC, + class TensorD, + class ShapePadding, + class StrideTraversal, + class ShapeDilation, + class EpilogueFusionParams +> +struct ConvReferenceImpl { + // Hard-code the accumulator type (float or double, see ElementAcc below) to avoid data loss in accumulating adds.
+ using ElementAcc = cutlass::platform::conditional_t, double, float>; + using ElementC = typename EpilogueFusionParams::ElementC; + using ElementOut = typename EpilogueFusionParams::ElementOut; + using ElementScalar = typename EpilogueFusionParams::ElementScalar; + using ElementCompute = typename EpilogueFusionParams::ElementCompute; + using ElementBias = typename EpilogueFusionParams::TensorBias::value_type; + using ActivationFunctor = typename EpilogueFusionParams::ActivationFunctor; + + // Input related converter + NumericConverter acc_converter; + NumericConverter residual_converter; + NumericConverter bias_converter; + // Scale related converter + NumericConverter scale_converter; + // Output related converter + NumericConverter output_converter; + + EpilogueFusionParams& epi_fusion_params_; + TensorA const& tensor_a_; + TensorB const& tensor_b_; + TensorC const& tensor_c_; + TensorD& tensor_d_; + + ShapePadding const& padding_; + StrideTraversal const& tstride_; + ShapeDilation const& dilation_; + + // Epilogue activation operation + ActivationFunctor epi_activation; + + ConvReferenceImpl( + TensorA const& tensor_a, + TensorB const& tensor_b, + TensorC const& tensor_c, + TensorD& tensor_d, + ShapePadding const& padding, + StrideTraversal const& tstride, + ShapeDilation const& dilation, + EpilogueFusionParams& epi_fusion_params) + : tensor_a_(tensor_a), + tensor_b_(tensor_b), + tensor_c_(tensor_c), + tensor_d_(tensor_d), + padding_(padding), + tstride_(tstride), + dilation_(dilation), + epi_fusion_params_(epi_fusion_params) + { + static_assert(rank(ShapePadding{}) == rank(ShapeDilation{})); + static_assert(rank(ShapePadding{}) == rank(StrideTraversal{})); + } + + void compute_reference() { + if constexpr (ConvOp == cutlass::conv::Operator::kFprop) { + fprop_reference(cute::Int{}); + } + else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) { + dgrad_reference(cute::Int{}); + } + else { + wgrad_reference(cute::Int{}); + } + } + +private: + // 
Specialization for 1D fprop kernel + void fprop_reference(cute::Int<1> spatial_dims) { + int32_t G = size<3>(tensor_d_); + int32_t N = size<2>(tensor_d_); + int32_t Q = size<1>(tensor_d_); + int32_t K = size<0>(tensor_d_); + int32_t S = size<1>(tensor_b_); + int32_t C = size<0>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t q = 0; q < Q; ++q) { + for (int32_t k = 0; k < K; ++k) { + auto accumulator = ElementAcc(0); + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + if (detail::is_activation_in_bounds(tensor_a_, n, w, c, g)) { + auto a = tensor_a_(c, w, n, g); + auto b = tensor_b_(c, s, k, g); + accumulator += ElementAcc(a * b); + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? 
+ epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[k]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g)); + } + tensor_d_(k, q, n, g) = output_converter(output); + } + } + } + } + + } + + // Specialization for 2D fprop kernel + void fprop_reference(cute::Int<2> spatial_dims) { + int32_t G = size<4>(tensor_d_); + int32_t N = size<3>(tensor_d_); + int32_t P = size<2>(tensor_d_); + int32_t Q = size<1>(tensor_d_); + int32_t K = size<0>(tensor_d_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + int32_t C = size<0>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + for (int32_t k = 0; k < K; ++k) { + auto accumulator = ElementAcc(0); + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + if (detail::is_activation_in_bounds(tensor_a_, n, h, w, c, g)) { + auto a = tensor_a_(c, w, h, n, g); + auto b = tensor_b_(c, s, r, k, g); + accumulator += ElementAcc(a * b); + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? 
+ epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[k]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g)); + } + tensor_d_(k, q, p, n, g) = output_converter(output); + } + } + } + } + } + + } + + // Specialization for 3D fprop kernel + void fprop_reference(cute::Int<3> spatial_dims) { + int32_t G = size<5>(tensor_d_); + int32_t N = size<4>(tensor_d_); + int32_t Z = size<3>(tensor_d_); + int32_t P = size<2>(tensor_d_); + int32_t Q = size<1>(tensor_d_); + int32_t K = size<0>(tensor_d_); + int32_t T = size<3>(tensor_b_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + int32_t C = size<0>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t z = 0; z < Z; ++z) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + for (int32_t k = 0; k < K; ++k) { + auto accumulator = ElementAcc(0); + for (int32_t t = 0; t < T; ++t) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_); + if 
(detail::is_activation_in_bounds(tensor_a_, n, d, h, w, c, g)) { + auto a = tensor_a_(c, w, h, d, n, g); + auto b = tensor_b_(c, s, r, t, k, g); + accumulator += ElementAcc(a * b); + } + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[k]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g)); + } + tensor_d_(k, q, p, z, n, g) = output_converter(output); + } + } + } + } + } + } + + } + + // Specialization for 1D dgrad kernel + void dgrad_reference(cute::Int<1> spatial_dims) { + int32_t G = size<3>(tensor_d_); + int32_t N = size<2>(tensor_d_); + int32_t W = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + int32_t K = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t w = 0; w < W; ++w) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t k = 0; k < K; ++k) { + for (int32_t s = 0; s < S; ++s) { + int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_); + + if (q % cute::get<0>(tstride_) == 0) { + q /= cute::get<0>(tstride_); + } else { + continue; + } + + if (detail::is_activation_in_bounds(tensor_a_, n, q, k, g)) { + accumulator += ElementAcc(tensor_a_(k, q, 
n, g) * tensor_b_(c, s, k, g)); + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) + ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) + ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g)); + } + tensor_d_(c, w, n, g) = output_converter(output); + } + } + } + } + + } + + // Specialization for 2D dgrad kernel + void dgrad_reference(cute::Int<2> spatial_dims) { + int32_t G = size<4>(tensor_d_); + int32_t N = size<3>(tensor_d_); + int32_t H = size<2>(tensor_d_); + int32_t W = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + int32_t K = size<3>(tensor_b_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t h = 0; h < H; ++h) { + for (int32_t w = 0; w < W; ++w) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t k = 0; k < K; ++k) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_); + int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_); + + if (q % cute::get<0>(tstride_) == 0) { + q /= cute::get<0>(tstride_); + } else { + continue; + } + + if (p % cute::get<1>(tstride_) == 0) { + p /= cute::get<1>(tstride_); + } 
else { + continue; + } + + if (detail::is_activation_in_bounds(tensor_a_, n, p, q, k, g)) { + accumulator += ElementAcc(tensor_a_(k, q, p, n, g) * tensor_b_(c, s, r, k, g)); + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) + ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) + ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g)); + } + + tensor_d_(c, w, h, n, g) = output_converter(output); + } + } + } + } + } + + } + + // Specialization for 3D dgrad kernel + void dgrad_reference(cute::Int<3> spatial_dims) { + int32_t G = size<5>(tensor_d_); + int32_t N = size<4>(tensor_d_); + int32_t D = size<3>(tensor_d_); + int32_t H = size<2>(tensor_d_); + int32_t W = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + int32_t K = size<4>(tensor_b_); + int32_t T = size<3>(tensor_b_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t d = 0; d < D; ++d) { + for (int32_t h = 0; h < H; ++h) { + for (int32_t w = 0; w < W; ++w) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t k = 0; k < K; ++k) { + for (int32_t t = 0; t < T; ++t) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + int32_t q = w + 
cute::get<0>(padding_) - s * cute::get<0>(dilation_); + int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_); + int32_t z = d + cute::get<2>(padding_) - t * cute::get<2>(dilation_); + + if (q % cute::get<0>(tstride_) == 0) { + q /= cute::get<0>(tstride_); + } else { + continue; + } + + if (p % cute::get<1>(tstride_) == 0) { + p /= cute::get<1>(tstride_); + } else { + continue; + } + + if (z % cute::get<2>(tstride_) == 0) { + z /= cute::get<2>(tstride_); + } else { + continue; + } + + if (detail::is_activation_in_bounds(tensor_a_, n, z, p, q, k, g)) { + accumulator += ElementAcc(tensor_a_(k, q, p, z, n, g) * tensor_b_(c, s, r, t, k, g)); + } + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) + ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) + ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g)); + } + tensor_d_(c, w, h, d, n, g) = output_converter(output); + } + } + } + } + } + } + + } + + // Specialization for 1D wgrad kernel + void wgrad_reference(cute::Int<1> spatial_dims) { + int32_t G = size<3>(tensor_d_); + int32_t N = + size<2>(tensor_a_); + int32_t Q = + size<1>(tensor_a_); + int32_t K = + size<0>(tensor_a_); + int32_t S = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif + for (int32_t g = 0; g < G; ++g) { + for 
(int32_t k = 0; k < K; ++k) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t n = 0; n < N; ++n) { + for (int32_t q = 0; q < Q; ++q) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + bool is_in_bounds = + detail::is_activation_in_bounds(tensor_b_, n, w, c, g); + if (is_in_bounds) { + auto act = + tensor_b_(c, w, n, g); + auto xformed_act = + tensor_a_(k, q, n, g); + accumulator += ElementAcc(act * xformed_act); + } + } + } + + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g)); + } + tensor_d_(c, s, k, g) = output_converter(output); + } + } + } + } + } + + // Specialization for 2D wgrad kernel + void wgrad_reference(cute::Int<2> spatial_dims) { + int32_t G = size<4>(tensor_d_); + int32_t N = + size<3>(tensor_a_); + int32_t P = + size<2>(tensor_a_); + int32_t Q = + size<1>(tensor_a_); + int32_t K = + size<0>(tensor_a_); + int32_t R = size<2>(tensor_d_); + int32_t S = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t k = 0; k < K; ++k) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s 
< S; ++s) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t n = 0; n < N; ++n) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + bool is_in_bounds = + detail::is_activation_in_bounds(tensor_b_, n, h, w, c, g); + if (is_in_bounds) { + auto act = + tensor_b_(c, w, h, n, g); + auto xformed_act = + tensor_a_(k, q, p, n, g); + accumulator += ElementAcc(act * xformed_act); + } + } + } + } + + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g)); + } + tensor_d_(c, s, r, k, g) = output_converter(output); + } + } + } + } + } + } + + // Specialization for 3D wgrad kernel + void wgrad_reference(cute::Int<3> spatial_dims) { + int32_t G = size<5>(tensor_d_); + int32_t N = + size<4>(tensor_a_); + int32_t Z = + size<3>(tensor_a_); + int32_t P = + size<2>(tensor_a_); + int32_t Q = + size<1>(tensor_a_); + int32_t K = + size<0>(tensor_a_); + int32_t T = size<3>(tensor_d_); + int32_t R = size<2>(tensor_d_); + int32_t S = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + +#if defined(_OPENMP) + 
#pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0 ; g < G; ++g) { + for (int32_t k = 0; k < K; ++k) { + for (int32_t t = 0; t < T; ++t) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t n = 0; n < N; ++n) { + for (int32_t z = 0; z < Z; ++z) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_); + bool is_in_bounds = + detail::is_activation_in_bounds(tensor_b_, n, d, h, w, c, g); + if (is_in_bounds) { + auto act = + tensor_b_(c, w, h, d, n, g); + auto xformed_act = + tensor_a_(k, q, p, z, n, g); + accumulator += ElementAcc(act * xformed_act); + } + } + } + } + } + + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? 
+ epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g)); + } + tensor_d_(c, s, r, t, k, g) = output_converter(output); + } + } + } + } + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // cutlass::reference::host + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..73298e5794f0f2658ef18fb3f46466c400fc831e --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h @@ -0,0 +1,802 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Reference implementation for convolution in host-side code. 
+*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Forward propagation +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// y = conv2d(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2dFprop( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + for (int k = 0; k < problem_size.K; ++k) { + + int group_idx = k / (problem_size.K / problem_size.groups); + int channels_per_group = problem_size.C / problem_size.groups; + + ElementAccumulator acc = ElementAccumulator(); + + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < channels_per_group; ++c) { + + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = 
problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) { + + ElementA a = tensor_x.at({n, h, w, c + group_idx * channels_per_group}); + ElementB b = tensor_w.at({k, r, s, c}); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_y_in.at(cutlass::make_Coord(n, p, q, k)); + } + + tensor_y_out.at(cutlass::make_Coord(n, p, q, k)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + } + } + } + } +} + +/// Depthwise-separable convolution +template , + typename InnerProductOp = multiply_add> +void Depsep_Fprop(cutlass::TensorView tensor_A, + cutlass::TensorView tensor_B, + cutlass::TensorView tensor_C, + cutlass::TensorView tensor_D, + ElementCompute alpha, + ElementCompute beta, + cutlass::Tensor4DCoord padding = cutlass::Tensor4DCoord(), + cutlass::Coord<2> conv_stride = cutlass::Coord<2>(), + cutlass::Coord<2> dilation = cutlass::Coord<2>(), + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < tensor_C.extent().n(); ++n) { + for (int p = 0; p < tensor_C.extent().h(); ++p) { + for (int q = 0; q < tensor_C.extent().w(); ++q) { + for (int g = 0; g < tensor_C.extent().c(); ++g) { + ElementAccumulator acc = ElementAccumulator(); + for (int r = 0; r < tensor_B.extent().h(); ++r) { + for (int s = 0; s < tensor_B.extent().w(); ++s) { + + // input activation H and W + int h = p * conv_stride[0] - padding[0] + r * dilation[0]; + int w = q 
* conv_stride[1] - padding[2] + s * dilation[1]; + + if (h < tensor_A.extent().h() && h >= 0 && w < tensor_A.extent().w() && w >= 0) { + ElementA a = tensor_A.at(cutlass::make_Coord(n, h, w, g)); + + ElementB b = (mode == cutlass::conv::Mode::kCrossCorrelation) + ? tensor_B.at(cutlass::make_Coord(g, r, s, 0)) + : tensor_B.at(cutlass::make_Coord( + g, tensor_B.extent().h() - r - 1, tensor_B.extent().w() - s - 1, 0)); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = tensor_C.at(cutlass::make_Coord(n, p, q, g)); + tensor_D.at(cutlass::make_Coord(n, p, q, g)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Dgrad / Deconv +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2dDgrad( + cutlass::conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + bool is_deconv = false) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int h = 0; h < problem_size.H; ++h) { + for (int w = 0; w < problem_size.W; ++w) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int r = 0; r < problem_size.R; ++r) 
{ + for (int s = 0; s < problem_size.S; ++s) { + for (int k = 0; k < problem_size.K; ++k) { + + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w; + + if (p >= 0 && (p % problem_size.stride_h) == 0 && + q >= 0 && (q % problem_size.stride_w) == 0) { + + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; +#if 0 + std::cout << "row:" + << n * problem_size.H * problem_size.W + + h * problem_size.W + + w << " " + << "n, p, q: (" + << n << ", " + << p << ", " + << q << ") * " + << "r, s: (" + << r << ", " + << s << ") [" + << ((p < problem_size.P && q < problem_size.Q) ? "true":"false") << "]" + << std::endl; +#endif + if (p < problem_size.P && q < problem_size.Q) { + + ElementA a = tensor_dy.at(cutlass::make_Coord(n, p, q, k)); + ElementB b = is_deconv ? 
tensor_w.at(cutlass::make_Coord(c, r, s, k)) + : tensor_w.at(cutlass::make_Coord(k, r, s, c)); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + + } // for (K) + } // for (S) + } // for (R) + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dx_in.at(cutlass::make_Coord(n, h, w, c)); + } + + tensor_dx_out.at(cutlass::make_Coord(n, h, w, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (W) + } // for (H) + } // for (N) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Wgrad +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2dWgrad( + cutlass::conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta) { + + InnerProductOp inner_product_op; + ConvertOp convert_op; + + // Apply MMA and accumulate ElementAccumulator + for (int k = 0; k < problem_size.K; ++k) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + + cutlass::Tensor4DCoord b_coord; + + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == 
cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + b_coord = make_Coord( + n, + p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h, + q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w, + c); + + if (b_coord.h() < problem_size.H && b_coord.h() >= 0 && + b_coord.w() < problem_size.W && b_coord.w() >= 0) { + + ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, p, q, k))); + ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord)); + acc = inner_product_op(a, b, acc); + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dw_in.at(cutlass::make_Coord(k, r, s, c)); + } + + tensor_dw_out.at(cutlass::make_Coord(k, r, s, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (S) + } // for (R) + } // for (K) +} + +/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2d( + conv::Operator convolutional_operator, + conv::Conv2dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + Conv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + case conv::Operator::kDeconv: + case conv::Operator::kDgrad: + Conv2dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv)); + break; + + case conv::Operator::kWgrad: + Conv2dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + default: + break; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// 3D convolution +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// y = conv3d(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename 
ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3dFprop( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int z = 0; z < problem_size.Z; ++z) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + for (int k = 0; k < problem_size.K; ++k) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int t = 0; t < problem_size.T; ++t) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < problem_size.C; ++c) { + + int filter_t = t; + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - t; + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int d = z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; + int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (d >= 0 && d < problem_size.D && + h >=0 && h < problem_size.H && + w >= 0 && w < problem_size.W) { + + ElementA a = tensor_x.at({n, d, h, w, c}); + ElementB b = tensor_w.at({k, t, r, s, c}); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_y_in.at(cutlass::make_Coord(n, z, p, q, k)); + } + + tensor_y_out.at(cutlass::make_Coord(n, z, p, q, k)) = + convert_op(alpha * ElementCompute(acc) + beta * 
ElementCompute(c_ref)); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Dgrad / Deconv +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3dDgrad( + cutlass::conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + bool is_deconv = false) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int d = 0; d < problem_size.D; ++d) { + for (int h = 0; h < problem_size.H; ++h) { + for (int w = 0; w < problem_size.W; ++w) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int t = 0; t < problem_size.T; ++t) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int k = 0; k < problem_size.K; ++k) { + + int filter_t = t; + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - t; + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int z = d + problem_size.pad_d - filter_t * problem_size.dilation_d; + int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w; + + if (z >= 0 && (z % problem_size.stride_d) == 0 && + p >= 0 && (p % problem_size.stride_h) == 0 && + q >= 0 && (q % 
problem_size.stride_w) == 0) { + + z = z / problem_size.stride_d; + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; + + if (z < problem_size.Z && p < problem_size.P && q < problem_size.Q) { + + ElementA a = tensor_dy.at(cutlass::make_Coord(n, z, p, q, k)); + ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, t, r, s, k)) + : tensor_w.at(cutlass::make_Coord(k, t, r, s, c)); + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + + } // for (K) + } // for (S) + } // for (R) + } // for (T) + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dx_in.at(cutlass::make_Coord(n, d, h, w, c)); + } + + tensor_dx_out.at(cutlass::make_Coord(n, d, h, w, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (W) + } // for (H) + } // for (D) + } // for (N) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Wgrad +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3dWgrad( + cutlass::conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta) { + + InnerProductOp inner_product_op; + ConvertOp convert_op; + + // Apply MMA and accumulate ElementAccumulator + for (int k = 0; k < problem_size.K; ++k) { + for (int t = 0; t < problem_size.T; ++t) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < 
problem_size.S; ++s) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int n = 0; n < problem_size.N; ++n) { + for (int z = 0; z < problem_size.Z; ++z) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + + int filter_t = t; + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - t; + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + Tensor5DCoord b_coord = make_Coord( + n, + z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d, + p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h, + q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w, + c); + + if (b_coord.d() < problem_size.D && b_coord.d() >= 0 && + b_coord.h() < problem_size.H && b_coord.h() >= 0 && + b_coord.w() < problem_size.W && b_coord.w() >= 0) { + + ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, z, p, q, k))); + ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord)); + + acc = inner_product_op(a, b, acc); + } + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dw_in.at(cutlass::make_Coord(k, t, r, s, c)); + } + + tensor_dw_out.at(cutlass::make_Coord(k, t, r, s, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (S) + } // for (R) + } // for (T) + } // for (K) +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Generic 3D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3d( + conv::Operator convolutional_operator, + conv::Conv3dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + Conv3dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + case conv::Operator::kDeconv: + case conv::Operator::kDgrad: + Conv3dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv)); + break; + + case conv::Operator::kWgrad: + Conv3dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + default: + break; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h new file mode 100644 index 0000000000000000000000000000000000000000..12ead83354b785096e8029b49f1ac353d5ce5f82 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h @@ -0,0 +1,66 @@ + +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/util/reference/host/tensor_reduce.h" +#include "cutlass/core_io.h" + +namespace cutlass { +namespace reference { +namespace host { + +/// Helper to compute the relative error metric for tensor A_computed w.r.t. to tensor A_reference +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorRelativeErrorMetric( + TensorView view_A_computed, + TensorView view_B_reference, + ComputeType identity = ComputeType() +) { + + return cutlass::reference::host::TensorNormDiff(view_A_computed, view_B_reference, identity) / + cutlass::reference::host::TensorNorm(view_B_reference, identity); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..2afee7b36d9822cc196f0f167f9dbec4c295d1a6 --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h @@ -0,0 +1,531 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" + +namespace cutlass { +namespace reference { +namespace host { + +template +struct CastIfScalar { + static Out cast(In in) { + return Out(in); + } +}; + +template +struct CastIfScalar, In> { + typedef cutlass::complex Out; + static Out cast(In in) { + return Out(static_cast(in)); + } +}; + +template +struct CastIfScalar, cutlass::complex> { + typedef cutlass::complex Out; + typedef cutlass::complex In; + static Out cast(In in) { + return Out(in); + } +}; + +template +Out cast_if_scalar(In in) { + return CastIfScalar::cast(in); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + + // Note: batch is ignored. 
+ int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b = tensor_b.at(MatrixCoord(k_block, col)); + + ComputeType compute_a(cast_if_scalar(a)); + ComputeType compute_b(cast_if_scalar(b)); + + accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum) { + compute_gemm( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Gemm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add-saturate +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 
2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for XOR-popc +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +/// Partial specialization for AND-popc +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 
2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Batched GEMM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a batch of GEMMs over a set of matrices of common dimension. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. 
+// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c, + AccumulatorType initial_accum) { + + typename TensorRefCollectionA::ConstIterator tensor_a_it = tensor_a.begin(); + typename TensorRefCollectionB::ConstIterator tensor_b_it = tensor_b.begin(); + typename TensorRefCollectionC::ConstIterator tensor_c_it = tensor_c.begin(); + + for (int batch = 0; + batch < batch_count; + ++batch, ++tensor_a_it, ++tensor_b_it, ++tensor_c_it) { + + Gemm + gemm; + + gemm(problem_size, alpha, *tensor_a_it, *tensor_b_it, beta, *tensor_c_it, + initial_accum); + } +} + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. 
+// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c) { + + BatchedGemm(problem_size, batch_count, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..221a6040854a74ce465af7b021bbbfae9b96a90b --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h @@ -0,0 +1,210 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/matrix_coord.h" + +#include "cutlass/tensor_view.h" + +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. 
Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b = tensor_b.at(MatrixCoord(k_block, col)); + + ComputeType 
a_ik = ComputeType(a); + ComputeType b_kj = ComputeType(b); + + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + + if (transform_b == ComplexTransform::kConjugate) { + b_kj = conj(b_kj); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_b.add_pointer_offset(batch_stride_B); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ElementD = ElementC +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d) { + + GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..507c37d9eb5a8c998f1075d547e8430b2edc5685 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h @@ -0,0 +1,228 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_ref_planar_complex.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. 
+/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + using ComplexA = typename TensorRefPlanarComplex::ComplexElement; + using ComplexB = typename TensorRefPlanarComplex::ComplexElement; + using ComplexC = typename TensorRefPlanarComplex::ComplexElement; + + // Note: batch is ignored. 
+ int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + complex accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + + ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); + ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); + + complex a = complex{ + ComputeType(a_ik.real()), + ComputeType(a_ik.imag()) + }; + + complex b = complex{ + ComputeType(b_kj.real()), + ComputeType(b_kj.imag()) + }; + + if (transform_a == ComplexTransform::kConjugate) { + a = conj(a); + } + + if (transform_b == ComplexTransform::kConjugate) { + b = conj(b); + } + + accum[i][j] = inner_product_op(a, b, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + complex acc{ + ScalarType(accum[i][j].real()), + ScalarType(accum[i][j].imag()) + }; + + ComplexC d_ij = tensor_c.at(coord); + + complex src{ + ScalarType(d_ij.real()), + ScalarType(d_ij.imag()) + }; + + complex result = alpha * acc + beta * src; + + d_ij.real() = convert_op(result.real()); + d_ij.imag() = convert_op(result.imag()); + + tensor_d.at(coord) = d_ij; + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a 
general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d) { + + GemmPlanarComplex( + problem_size, + alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, + tensor_c, + tensor_d, + complex()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dd54dc6e378d0d0f0549ec922da8357841ac558f --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp @@ -0,0 +1,916 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GETT in host-side code. 
+*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/gemm.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/relatively_equal.h" + +#include "cute/tensor.hpp" +#include "cute/pointer.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::reference::host { + +template +struct ElementTraits { + using type = T; +}; + +template +struct ElementTraits().get()), void> > > { + using type = decltype(std::declval().get()); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////// +// +// Gett Mainloop Parameters +// +/////////////////////////////////////////////////////////// + +template< + class ElementAccumulator_, + class TensorA_, // (M, K, L) + class TensorB_ // (N, K, L) + + , class TensorSfA_ = TensorA_, + class TensorSfB_ = TensorB_ + +> +struct GettMainloopParams { + using ElementAccumulator = ElementAccumulator_; + using TensorA = TensorA_; + using TensorB = TensorB_; + using EngineA = typename TensorA::engine_type; + using LayoutA = typename TensorA::layout_type; + using EngineB = typename TensorB::engine_type; + using LayoutB = typename TensorB::layout_type; + + TensorA A{}; + TensorB B{}; + + ComplexTransform transform_A = ComplexTransform::kNone; + ComplexTransform transform_B = ComplexTransform::kNone; + + + using TensorSfA = TensorSfA_; + using TensorSfB = TensorSfB_; + using EngineSfA = typename TensorSfA::engine_type; + using LayoutSfA = typename TensorSfA::layout_type; + using EngineSfB = typename TensorSfB::engine_type; + using LayoutSfB = typename TensorSfB::layout_type; + TensorSfA_ SfA{}; + TensorSfB_ SfB{}; + + + GettMainloopParams() {} + + GettMainloopParams(TensorA tensor_A, TensorB 
tensor_B) + : A(tensor_A), B(tensor_B) {} + + + GettMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB) + : A(tensor_A), SfA(tensor_SfA), + B(tensor_B), SfB(tensor_SfB) {} + + +}; + + + +//////////////////////////////////////////////////////////////////////// +// +// Gett Mainloop Parameter Specialization for Block Scaled GEMM kernels +// +//////////////////////////////////////////////////////////////////////// + +template< + class ElementAccumulator_, + class TensorA_, // (M, K, L) + class TensorSfA_, // (M, K, L) + class TensorB_, // (N, K, L) + class TensorSfB_ // (N, K, L) +> +struct GettBlockScalingMainloopParams : public GettMainloopParams { + using Base = GettMainloopParams; + using ElementAccumulator = typename Base::ElementAccumulator; + using TensorA = typename Base::TensorA; + using TensorB = typename Base::TensorB; + using EngineA = typename Base::EngineA; + using LayoutA = typename Base::LayoutA; + using EngineB = typename Base::EngineB; + using LayoutB = typename Base::LayoutB; + ComplexTransform transform_A = Base::transform_A; + ComplexTransform transform_B = Base::transform_B; + + using TensorSfA = typename Base::TensorSfA; + using TensorSfB = typename Base::TensorSfB; + using EngineSfA = typename Base::EngineSfA; + using LayoutSfA = typename Base::LayoutSfA; + using EngineSfB = typename Base::EngineSfB; + using LayoutSfB = typename Base::LayoutSfB; + + GettBlockScalingMainloopParams() {} + + GettBlockScalingMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB) + : Base(tensor_A, tensor_SfA, tensor_B, tensor_SfB) {} + + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class SfStrategy { + None = 0, + SfDGen = 1 +}; + + +/////////////////////////////////////////////////////////// +// +// Gett Epilogue Parameters +// +/////////////////////////////////////////////////////////// + +template< + class 
ElementScalar_, + class ElementScalingFactor_, + class ElementAccumulator_, + class ElementCompute_, + class TensorC_, // (M, N, L) + class TensorD_, // (M, N, L) + class VectorBias_ = decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // (M, 1) + class TensorAux_ = decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // (M, N, L) + class VectorAlpha_ = decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // (M, 1) + class VectorBeta_ = VectorAlpha_, // (M, 1) + class ActivationFunctor_ = cutlass::epilogue::thread::Identity, + class TensorSFD_ = TensorD_, + class SFD_VectorSize_ = cute::Int<0>, + class BiasBinaryOp_ = cutlass::plus, + bool PerColumnBias_ = false + , + SfStrategy SfGenStrategy_ = SfStrategy::None +> +struct GettEpilogueParams { + using ElementScalar = ElementScalar_; + using ElementScalingFactor = ElementScalingFactor_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using TensorC = TensorC_; + using TensorD = TensorD_; + using TensorAux = TensorAux_; + using VectorBias = VectorBias_; + using VectorAlpha = VectorAlpha_; + using VectorBeta = VectorBeta_; + using TensorSFD = TensorSFD_; + using SFD_VectorSize = SFD_VectorSize_; + using ActivationFunctor = ActivationFunctor_; + using BiasBinaryOp = BiasBinaryOp_; + + using EngineC = typename TensorC::engine_type; + using LayoutC = typename TensorC::layout_type; + using EngineD = typename TensorD::engine_type; + using LayoutD = typename TensorD::layout_type; + using EngineSfD = typename TensorSFD::engine_type; + using LayoutSfD = typename TensorSFD::layout_type; + static constexpr bool PerColumnBias = PerColumnBias_; + static constexpr SfStrategy SfGenStrategy = SfGenStrategy_; + + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + + TensorC C{}; + TensorD D{}; + VectorBias Bias{}; + TensorAux Aux{}; + VectorAlpha Valpha{}; + 
VectorBeta Vbeta{}; + TensorSFD SfD{}; + ElementCompute st = ElementCompute(1); + + ElementAccumulator* abs_max_D = nullptr; + ElementAccumulator* abs_max_Aux = nullptr; + + ElementScalingFactor scale_a = ElementScalingFactor(1); + ElementScalingFactor scale_b = ElementScalingFactor(1); + ElementScalingFactor scale_c = ElementScalingFactor(1); + ElementScalingFactor scale_d = ElementScalingFactor(1); + ElementScalingFactor scale_aux = ElementScalingFactor(1); + + bool beta_per_channel_scaling = false; + GettEpilogueParams() {} + + GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D) + : alpha(alpha), beta(beta), C(tensor_C), D(tensor_D) {} + + + GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st) + : alpha(alpha), beta(beta), C(tensor_C), D(tensor_D), SfD(tensor_SfD), st(epilogue_st) {} + + + GettEpilogueParams( + ElementScalar alpha, ElementScalar beta, + TensorC tensor_C, TensorD tensor_D, + VectorBias bias, TensorAux tensor_aux, + VectorAlpha vector_alpha, VectorBeta vector_beta) + : alpha(alpha), beta(beta), + C(tensor_C), D(tensor_D), + Bias(bias), Aux(tensor_aux), + Valpha(vector_alpha), Vbeta(vector_beta) {} +}; + + + +//////////////////////////////////////////////////////////////////////// +// +// Gett Epilogue Parameters Specialization for Block Scaled GEMM kernels +// +//////////////////////////////////////////////////////////////////////// + +template< + class ElementScalar_, + class ElementAccumulator_, + class ElementCompute_, + class TensorC_, + class TensorD_, + class TensorSfD_ = TensorD_, + class SFD_VectorSize_ = cute::Int<0>, + SfStrategy SfGenStrategy_ = SfStrategy::None +> +struct GettBlockScalingEpilogueParams : public GettEpilogueParams< + ElementScalar_, // ElementScalar + ElementScalar_, // ElementScalingFactor + ElementAccumulator_, // ElementAccumulator + ElementCompute_, // ElementCompute + TensorC_, // 
TensorC (M, N, L) + TensorD_, // TensorD (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1) + cutlass::epilogue::thread::Identity, // + TensorSfD_, // TensorSfD + SFD_VectorSize_, // SFD_VectorSize + cutlass::plus, // class BiasBinaryOp_ = + false, //PerColumnBias_ + SfGenStrategy_ // SfGenStrategy + > { + using Base = GettEpilogueParams< + ElementScalar_, // ElementScalar + ElementScalar_, // ElementScalingFactor + ElementAccumulator_, // ElementAccumulator + ElementCompute_, // ElementCompute + TensorC_, // TensorC (M, N, L) + TensorD_, // TensorD (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1) + cutlass::epilogue::thread::Identity, // + TensorSfD_, // TensorSfD + SFD_VectorSize_, // SFD_VectorSize + cutlass::plus, // BiasBinaryOp + false, // PerColumnBias + SfGenStrategy_ // SfGenStrategy + >; + using ElementScalar = typename Base::ElementScalar; + using ElementScalingFactor = typename Base::ElementScalingFactor; + using ElementAccumulator = typename Base::ElementAccumulator; + using ElementCompute = typename Base::ElementCompute; + using TensorC = typename Base::TensorC; + using TensorD = typename Base::TensorD; + using TensorAux = typename Base::TensorAux; + using VectorBias = typename 
Base::VectorBias; + using VectorAlpha = typename Base::VectorAlpha; + using VectorBeta = typename Base::VectorBeta; + using TensorSFD = typename Base::TensorSFD; + using SFD_VectorSize = typename Base::SFD_VectorSize; + using ActivationFunctor = typename Base::ActivationFunctor; + using BiasBinaryOp = typename Base::BiasBinaryOp; + + using EngineC = typename Base::EngineC; + using LayoutC = typename Base::LayoutC; + using EngineD = typename Base::EngineD; + using LayoutD = typename Base::LayoutD; + using EngineSfD = typename Base::EngineSfD; + using LayoutSfD = typename Base::LayoutSfD; + static constexpr bool PerColumnBias = Base::PerColumnBias; + static constexpr SfStrategy SfGenStrategy = Base::SfGenStrategy; + + GettBlockScalingEpilogueParams() {} + + GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D) + : Base(alpha, beta, tensor_C, tensor_D) {} + + GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD) + : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, ElementCompute{0}) {} + + GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st) + : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, epilogue_st) {} +}; + + + + + +/////////////////////////////////////////////////////////// +// +// Generic Gett 3x Implementation +// +/////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +void compute_1d_scaling_factor_and_quantized_output( + EpilogueParams const& epilogue_params, + TensorD &tensor_D, + TensorSFD &tensor_SfD, + int64_t m, + int64_t n, + int64_t l, + ElementCompute (&acc)[kBlockM][kBlockN]) +{ + using ElementD = typename ElementTraits::type; + using ElementSfD = typename ElementTraits::type; + + int const M = 
cute::size<0>(tensor_D.layout()); + int const N = cute::size<1>(tensor_D.layout()); + int const L = cute::size<2>(tensor_D.layout()); + + auto mul = cutlass::multiplies{}; + auto div = divides{}; + // Get FP max + ElementCompute fp_max = ElementCompute(std::numeric_limits::max()); + float scale_down_factor = div(1.0f, fp_max); + // Get st' = st / FP max + ElementCompute st_scaled_down = mul(epilogue_params.st, scale_down_factor); + + absolute_value_op abs_op; + maximum_with_nan_propogation max_op; + + if constexpr (cute::is_constant<1, decltype(cute::stride<0,0,1>(tensor_SfD))>::value) { + // MN major output + int const NumVecPerBlock = ceil_div(kBlockM, kVectorSize); + // Col major output + for (int n_b = 0; n_b < kBlockN; ++n_b) { + for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) { + int64_t col = n + n_b; + + /// Step1: get max across a vector + ElementCompute accum_max = ElementCompute(0); + for (int v = 0; v < kVectorSize; v++) { + int accum_row = v_b * kVectorSize + v; + int64_t output_row = accum_row + m; + if (output_row < M && col < N) { + accum_max = max_op(accum_max, abs_op(acc[accum_row][n_b])); + } + } + + /// Step2: Compute Scale + ElementCompute pvscale = mul(accum_max, st_scaled_down); + ElementSfD qpvscale = static_cast(pvscale); + // Store the Scaling Factors + int64_t sf_row = m + kVectorSize * v_b; + if (sf_row < M && col < N) { + tensor_SfD(sf_row, col, l) = qpvscale; + } + + /// Step3: Compute quantized output values + ElementCompute qpvscale_up = NumericConverter{}(qpvscale); + // Get float reciprocal + ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up); + ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp); + // Map INF to fp32::max + acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); + // Store the intermediate_accum + for (int v = 0; v < kVectorSize; v++) { + int accum_row = v_b * kVectorSize + v; + int64_t output_row = accum_row + m; + if (output_row < M && col < N) { + 
acc[accum_row][n_b] = mul(acc[accum_row][n_b], acc_scale); + } + } + } + } + } + else { + int const NumVecPerBlock = ceil_div(kBlockN, kVectorSize); + // row major output + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) { + int64_t row = m + m_b; + + /// Step1: get max across a vector + ElementCompute accum_max = ElementCompute(0); + for (int v = 0; v < kVectorSize; v++) { + int accum_col = v_b * kVectorSize + v; + int64_t output_col = accum_col + n; + if (row < M && output_col < N) { + accum_max = max_op(accum_max, abs_op(acc[m_b][accum_col])); + } + } + + /// Step2: Compute Scale + ElementCompute pvscale = mul(accum_max, st_scaled_down); + ElementSfD qpvscale = static_cast(pvscale); + // Store the Scaling Factors + int64_t sf_col = n + kVectorSize * v_b; + + if (row < M && sf_col < N) { + tensor_SfD(row, sf_col, l) = qpvscale; + } + + /// Step3: Compute quantized output values + ElementCompute qpvscale_up = NumericConverter{}(qpvscale); + // Get float reciprocal + ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up); + ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp); + // Map INF to fp32::max + acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); + // Store the intermediate_accum + for (int v = 0; v < kVectorSize; v++) { + int accum_col = v_b * kVectorSize + v; + int64_t output_col = accum_col + n; + if (row < M && output_col < N) { + acc[m_b][accum_col] = mul(acc[m_b][accum_col], acc_scale); + } + } + } + } + } +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - General Tensor-Tensor contraction reference kernel +template < + class MainloopParams, + class EpilogueParams +> +void Gett( + MainloopParams const& mainloop_params, + EpilogueParams const& epilogue_params) +{ + + static int constexpr kBlockM = 64; + static int constexpr kBlockN = 64; + +#if defined(_OPENMP) + #pragma omp 
parallel for collapse(3) +#endif + for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { + for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { + typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN]; + gett_mainloop(mainloop_params, m, n, l, acc); + gett_epilogue(epilogue_params, m, n, l, acc); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - Mainloop +template +void gett_mainloop( + MainloopParams const& mainloop_params, + int64_t m, + int64_t n, + int64_t l, + ElementAccumulator (&acc)[kBlockM][kBlockN]) +{ + + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); + static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); + + using cute::raw_pointer_cast; + + using ElementA = typename ElementTraits::type; + using ElementB = typename ElementTraits::type; + + + using ElementSFA = typename ElementTraits::type; + using ElementSFB = typename ElementTraits::type; + + + using RingOp = multiply_add; + RingOp fma_op; + + // Zero out accumulators + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // Compute on this k-block + for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { + // Load A + ElementAccumulator a_frag[kBlockM]; + for (int m_b = 0; m_b < kBlockM; ++m_b) { + if (m + m_b < cute::size<0>(mainloop_params.A.layout())) { + // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. 
+ a_frag[m_b] = static_cast(ElementA(mainloop_params.A(m + m_b, k, l))); + + + if constexpr (not cute::is_same_v){ + // Load SFA + auto sfa = static_cast(mainloop_params.SfA(m + m_b, k, l)); + a_frag[m_b] *= sfa; + } + + + if (mainloop_params.transform_A == ComplexTransform::kConjugate) { + a_frag[m_b] = conj(a_frag[m_b]); + } + } else { + a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // Load B + ElementAccumulator b_frag[kBlockN]; + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { + // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. + b_frag[n_b] = static_cast(ElementB(mainloop_params.B(n + n_b, k, l))); + + + if constexpr (not cute::is_same_v){ + // Load SFB + auto sfb = static_cast(mainloop_params.SfB(n + n_b, k, l)); + b_frag[n_b] *= sfb; + } + + + if (mainloop_params.transform_B == ComplexTransform::kConjugate) { + b_frag[n_b] = conj(b_frag[n_b]); + } + } else { + b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // do compute + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + acc[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc[m_b][n_b]); + } + } + + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - Epilogue +template +void gett_epilogue( + EpilogueParams const& epilogue_params, + int64_t m, + int64_t n, + int64_t l, + ElementAccumulator (&acc)[kBlockM][kBlockN]) +{ + static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B"); + static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); + + using cute::raw_pointer_cast; + + using ElementCompute = typename EpilogueParams::ElementCompute; + using ElementC = typename EpilogueParams::TensorC::value_type; + using ElementD = typename EpilogueParams::TensorD::value_type; + using ElementSfD = typename 
EpilogueParams::TensorSFD::value_type; + using ElementAux = typename EpilogueParams::TensorAux::value_type; + using ElementBias = typename EpilogueParams::VectorBias::value_type; + using ElementScalar = typename EpilogueParams::ElementScalar; + using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor; + using ActivationFunctor = typename EpilogueParams::ActivationFunctor; + using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp; + + constexpr bool PerColBias = EpilogueParams::PerColumnBias; + constexpr SfStrategy SfGenStrategy = EpilogueParams::SfGenStrategy; + + constexpr bool IsScalingAndAmaxOutputNeeded = + cute::is_same_v or + cute::is_same_v; + + constexpr bool IsScalingAndAmaxAuxOutputNeeded = + cute::is_same_v or + cute::is_same_v; + + constexpr bool IsReLUAuxNeeded = + (cute::is_same_v> or + cute::is_same_v>) and + cute::is_same_v; + constexpr bool UseReLU = + cute::is_same_v>; // Treat Clamp as ReLU + + constexpr bool IsBackpropFusion = + cute::is_same_v> or + cute::is_same_v>; + + // Input related converter + NumericConverter accumulator_converter; + NumericConverter source_converter; + NumericConverter bias_converter; + [[maybe_unused]] NumericConverter aux_source_converter; + + // Scale related converter + NumericConverter scale_converter; + NumericConverter scaling_factor_converter; + + // Abs max converter + [[maybe_unused]] NumericConverter abs_max_output_converter; + + // Output related converter + NumericConverter destination_converter; + [[maybe_unused]] NumericConverter aux_destination_converter; + NumericConverter dBias_converter; + + // Epilogue operations + multiply_add epilogue_fma; + multiplies mul; + plus add; + + // Activation operation + ActivationFunctor activation; + + // Bias binary operation + BiasBinaryOp bias_op; + + // Do conversion + ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); + ElementCompute converted_beta = scale_converter(epilogue_params.beta); + ElementCompute 
converted_scale_a = scaling_factor_converter(epilogue_params.scale_a); + ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b); + ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c); + ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d); + ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux); + + // Init local var + [[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0); + [[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0); + + converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); + converted_beta = mul(converted_beta, converted_scale_c); + + ElementCompute inter_accum[kBlockM][kBlockN]; + + for (int m_b = 0; m_b < kBlockM; ++m_b) { + ElementCompute local_dBias = ElementCompute(0); + + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { + // Convert every type to ElementCompute first, do compute, convert to output type, write it out + ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); + // vector alpha + if (raw_pointer_cast(epilogue_params.Valpha.data())) { + converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b, n + n_b, l)); + converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); + } + ElementCompute output = mul(converted_alpha, converted_acc); + + if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) { + ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? 
n + n_b : m + m_b)); + output = bias_op(output, converted_bias); + } + + if (raw_pointer_cast(epilogue_params.C.data())) { + ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); + // vector beta + if (epilogue_params.Vbeta.data()) { + converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b, n + n_b, l)); + converted_beta = mul(converted_beta, converted_scale_c); + } + output = epilogue_fma(converted_beta, converted_src, output); + } + + if constexpr (IsBackpropFusion) { + ElementAux aux_input = ElementAux(0); + if (raw_pointer_cast(epilogue_params.Aux.data())) { + aux_input = epilogue_params.Aux(m + m_b, n + n_b, l); + } + + output = activation(output, aux_source_converter(aux_input)); + local_dBias = add(local_dBias, output); + } + else { + if (raw_pointer_cast(epilogue_params.Aux.data())) { + auto aux_output = output; + if constexpr (IsScalingAndAmaxAuxOutputNeeded) { + maximum_absolute_value_reduction amax_op; + local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output); + aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0)); + } + + if constexpr (IsReLUAuxNeeded) { + epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? 
uint1b_t(1) : uint1b_t(0); + } else { + epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output); + } + } + + if constexpr (UseReLU) { + cutlass::epilogue::thread::ReLU relu; + output = relu(output); + } + else { + output = activation(output); + } + } + + if constexpr (IsScalingAndAmaxOutputNeeded) { + maximum_absolute_value_reduction amax_op; + local_abs_max_output = amax_op(local_abs_max_output, output); + output = epilogue_fma(converted_scale_d, output, ElementCompute(0)); + } + + inter_accum[m_b][n_b] = ElementCompute(output); + } + } // n_b + + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) { + if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) { + ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b)); + local_dBias = add(local_dBias, converted_dBias); + epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias); + } + } + } // m_b + + if constexpr ( + SfGenStrategy == SfStrategy::SfDGen + ) { + // 1d scale factor generation + constexpr int kVectorSize = typename EpilogueParams::SFD_VectorSize{}; + if (epilogue_params.SfD.data() != nullptr) { + compute_1d_scaling_factor_and_quantized_output(epilogue_params, epilogue_params.D, epilogue_params.SfD, m, n, l, inter_accum); + } + } + + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { + epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]); + } + } + } + +#if defined(_OPENMP) + #pragma omp critical(Abs_Max_Data_Update) +#endif + { + if constexpr (IsScalingAndAmaxOutputNeeded) { + if (epilogue_params.abs_max_D) { + *epilogue_params.abs_max_D = maximum_with_nan_propogation{}( + *epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output)); + } + } + + if constexpr 
(IsScalingAndAmaxAuxOutputNeeded) { + if (epilogue_params.abs_max_Aux) { + *epilogue_params.abs_max_Aux = maximum_with_nan_propogation{}( + *epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output)); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +auto make_layout_rank3(const TensorType& tensor) { + // append a batch mode of size 1 if we do not have tensors that are rank 3 + return make_layout( + make_shape(cute::get<0>(tensor.shape()), cute::get<1>(tensor.shape()), cute::Int<1>{}), + make_stride(cute::get<0>(tensor.stride()), cute::get<1>(tensor.stride()), int64_t(cosize(tensor.layout())))); +} + +/// GEMM - General Matrix-Matrix contraction without conjugation options. Rank-2 operands get a unit batch mode appended (make_layout_rank3) and the params are rebuilt before calling Gett; any other rank is forwarded to Gett unchanged. +template < + class MainloopParams, + class EpilogueParams +> +void Gemm3x( + MainloopParams const& mainloop_params, + EpilogueParams const& epilogue_params) +{ + using namespace cute; + + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{})); + static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{})); + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{})); + + if constexpr (cute::rank(typename MainloopParams::LayoutA{}) == 2) { + cute::Layout layout_A = make_layout_rank3(mainloop_params.A); + cute::Layout layout_B = make_layout_rank3(mainloop_params.B); + cute::Layout layout_C = make_layout_rank3(epilogue_params.C); + cute::Layout layout_D = make_layout_rank3(epilogue_params.D); + cute::Layout layout_Aux = make_layout_rank3(epilogue_params.Aux); + cute::Layout layout_Bias = make_layout_rank3(epilogue_params.Bias); + cute::Layout layout_Valpha = make_layout_rank3(epilogue_params.Valpha); + cute::Layout layout_Vbeta = make_layout_rank3(epilogue_params.Vbeta); + + auto TensorA = make_tensor(mainloop_params.A.data(), layout_A); + 
auto TensorB = make_tensor(mainloop_params.B.data(), layout_B); + auto TensorC = make_tensor(epilogue_params.C.data(), layout_C); + auto TensorD = make_tensor(epilogue_params.D.data(), layout_D); + auto TensorAux = make_tensor(epilogue_params.Aux.data(), layout_Aux); + auto VectorBias = make_tensor(epilogue_params.Bias.data(), layout_Bias); + auto VectorAlpha = make_tensor(epilogue_params.Valpha.data(), layout_Valpha); + auto VectorBeta = make_tensor(epilogue_params.Vbeta.data(), layout_Vbeta); + + // Reconstruct mainloop params + GettMainloopParams + mainloop_params_converted{TensorA, + TensorB, + mainloop_params.transform_A, + mainloop_params.transform_B}; + + // Reconstruct epilogue params (NOTE(review): the amax members are read as abs_amax_D / abs_amax_Aux below, while the epilogue code in this file reads abs_max_D / abs_max_Aux - confirm the spelling against GettEpilogueParams) + GettEpilogueParams + epilogue_params_converted{epilogue_params.alpha, + epilogue_params.beta, + TensorC, + TensorD, + VectorBias, + TensorAux, + VectorAlpha, + VectorBeta, + epilogue_params.abs_amax_D, + epilogue_params.abs_amax_Aux, + epilogue_params.scale_a, + epilogue_params.scale_b, + epilogue_params.scale_c, + epilogue_params.scale_d, + epilogue_params.scale_aux + }; + + Gett(mainloop_params_converted, epilogue_params_converted); + } + else { + // if we already have a batch mode, just pass it through + Gett(mainloop_params, epilogue_params); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // cutlass::reference::host + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h new file mode 100644 index 0000000000000000000000000000000000000000..67867533d5783b6e0047ac2110dc47adaa277e25 --- /dev/null +++ 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h @@ -0,0 +1,261 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for Rank 2k update in host-side code. + + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a rank-2k update D = alpha * (A * B^T + B * A^T) + beta * C on matrices (tensors of rank=2) +/// pointed to by TensorRef objects; only the triangle selected by FillModeC is accumulated and written. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + FillMode FillModeC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_rank2k( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + static_assert( + FillModeC == FillMode::kLower || + FillModeC == FillMode::kUpper, + "Fill Mode can either be Lower or Upper."); + + using CompareOp = typename platform::conditional<(FillModeC == FillMode::kLower), + std::greater_equal, + std::less_equal>::type; + + // Note: batch is ignored. 
+ // Note: M is same as N for Rank 2k update + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speed up the reference implementation + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp compare_op; + + for (int row_block = 0; row_block < N; row_block += Nblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Nblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Nblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Nblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < N && col < N && compare_op(row, col)) + { + + // A x B^T + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b_t = tensor_b.at(MatrixCoord(col, k_block)); + + ComputeType compute_a(cast_if_scalar(a)); + ComputeType compute_b_t(cast_if_scalar(b_t)); + + accum[i][j] = inner_product_op(compute_a, compute_b_t, accum[i][j]); + + // B x A^T + ElementB b = tensor_b.at(MatrixCoord(row, k_block)); + ElementA a_t = tensor_a.at(MatrixCoord(col, k_block)); + + ComputeType compute_b(cast_if_scalar(b)); + ComputeType compute_a_t(cast_if_scalar(a_t)); + + accum[i][j] = inner_product_op(compute_b, compute_a_t, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Nblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < N && col < N && + ( (FillModeC == FillMode::kLower && row >= col) || + (FillModeC == FillMode::kUpper && row <= col) ) + ) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general Rank 2k 
update (tensors of rank=2) pointed to by TensorRef +/// objects. This overload passes tensor_c as both source and destination, so C is updated in place. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + FillMode FillModeC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_rank2k( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum) { + compute_rank2k( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + FillMode FillModeC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Rank2K; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Rank2K { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_rank2k>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_rank2k>( + 
problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..a738101660f7ebbdd7c7796d46df244f1e3f5f70 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h @@ -0,0 +1,318 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued Rank 2K update in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Note: batch_count problems are processed; after each one the batch_idx loop below offsets every TensorRef by its batch stride. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Rank2K update operates on A=NxK, B=NxK, and C=NxN + assert(M==N); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // A x B^T (Symmetric) or A x B^H (Hermitian) + 
// complex conjugation on operandB (b_t) is function of blas3 computation + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b_t = (blas_mode == BlasMode::kHermitian) ? + conj(tensor_b.at(MatrixCoord(col, k_block))) : + tensor_b.at(MatrixCoord(col, k_block)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_jk = ComputeType(b_t); + + // complex conjugation is a function of operand layouts + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + // complex conjugation is a function of operand layouts + if (transform_b == ComplexTransform::kConjugate) { + b_jk = conj(b_jk); + } + + accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); + } + } + } + } + + /* HER2K need two epilogues to handle complex alpha value */ + if ( blas_mode == BlasMode::kHermitian ) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N && + ((fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col)) + ) { + + ScalarType c = tensor_c.at(coord); + // The imaginary parts of the diagonal elements of + // a complex data type are assumed and set to zero + if (blas_mode == BlasMode::kHermitian) { + c = (row == col) ? 
real(c) : c; + } + + tensor_d.at(coord) = convert_op(alpha * + ScalarType(accum[i][j]) + + beta * c); + } + } + } + + /* Zeroing out accum for second HERK */ + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // B x A^T (Symmetric) or B x A^H (Hermitian) + // complex conjugation on operandB (a_t) is function of blas3 computation + ElementB b = tensor_b.at(MatrixCoord(row, k_block)); + ElementA a_t = (blas_mode == BlasMode::kHermitian) ? + conj(tensor_a.at(MatrixCoord(col, k_block))): + tensor_a.at(MatrixCoord(col, k_block)); + + ComputeType b_ik = ComputeType(b); + ComputeType a_jk = ComputeType(a_t); + + // complex conjugation here is a function of operand layouts + if (transform_b == ComplexTransform::kConjugate) { + b_ik = conj(b_ik); + } + // complex conjugation here is a function of operand layouts + if (transform_a == ComplexTransform::kConjugate) { + a_jk = conj(a_jk); + } + + accum[i][j] = inner_product_op(b_ik, a_jk, accum[i][j]); + } + } + } + } + + ScalarType alpha_hermitian = (blas_mode == BlasMode::kHermitian) ? + conj(alpha) : alpha; + ScalarType beta_hermitian = (blas_mode == BlasMode::kHermitian) ? + 1 : beta; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N && + ((fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col)) + ) { + + ScalarType d = (blas_mode == BlasMode::kHermitian) ? 
+ tensor_d.at(coord) : tensor_c.at(coord); + + ScalarType tmp_d = convert_op( + alpha_hermitian * ScalarType(accum[i][j]) + + beta_hermitian * d); + + if (blas_mode == BlasMode::kHermitian && row == col ) { + tensor_d.at(coord) = real(tmp_d); + } else { + tensor_d.at(coord) = tmp_d; + } + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_b.add_pointer_offset(batch_stride_B); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + FillMode fill_mode_c, + BlasMode blas_mode) { + + Rank2KComplex( + problem_size, alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, tensor_c, tensor_d, + ScalarType(0), + fill_mode_c, + blas_mode); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h new file mode 100644 index 
0000000000000000000000000000000000000000..1aad33fd643b60752bc0845e403cebc43ad7d047 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued Rank K update in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. 
+template < + typename ElementA, + typename LayoutA, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Note: batch_count problems are processed; after each one the batch_idx loop below offsets every TensorRef by its batch stride. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Rank-K update operates on A=NxK and C=NxN (this function takes no B operand) + assert(M==N); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // A x A^T (Symmetric) or A x A^H (Hermitian) + // complex conjugation on operandB (a_t) (function of blas3 computation) + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementA a_t = 
(blas_mode == BlasMode::kHermitian) ? + conj(tensor_a.at(MatrixCoord(col, k_block))) : + tensor_a.at(MatrixCoord(col, k_block)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_jk = ComputeType(a_t); + + // complex conjugation (function of input layouts) + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + // complex conjugation (function of input layouts) + if (transform_a == ComplexTransform::kConjugate) { + b_jk = conj(b_jk); + } + + accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); + + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N && + ((fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col)) + ) { + + ScalarType c = tensor_c.at(coord); + // The imaginary parts of the diagonal elements of + // a complex data type are assumed and set to zero + if (blas_mode == BlasMode::kHermitian) { + c = (row == col) ? real(c) : c; + } + + ScalarType tmp_d = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * c); + + if (blas_mode == BlasMode::kHermitian && row == col ) { + tensor_d.at(coord) = real(tmp_d); + } else { + tensor_d.at(coord) = tmp_d; + } + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. 
+template < + typename ElementA, + typename LayoutA, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void RankKComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + FillMode fill_mode_c, + BlasMode blas_mode) { + + Rank2KComplex( + problem_size, alpha, + tensor_a, transform_a, + beta, tensor_c, tensor_d, + ScalarType(0), + fill_mode_c, + blas_mode); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h new file mode 100644 index 0000000000000000000000000000000000000000..34f9648f25f8965f6730999b7763220c360683a8 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h @@ -0,0 +1,285 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for SYMM update in host-side code. + + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. 
+template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_symm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + static_assert(SideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert( + FillModeA == FillMode::kLower || + FillModeA == FillMode::kUpper, + "Fill Mode can either be Lower or Upper."); + + using CompareOp_w_diag = typename TrMatrixCompareOp::Type; + using CompareOp_wo_diag = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. 
+ int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp_w_diag compare_op_1; + CompareOp_wo_diag compare_op_2; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a_1 = ElementA(); + ElementB b_1 = ElementB(); + ElementA a_2 = ElementA(); + ElementB b_2 = ElementB(); + + // A x B or B x A (with diagonal) + if (SideModeA == SideMode::kLeft) { + a_1 = (compare_op_1(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(); + b_1 = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a_1 = tensor_b.at(MatrixCoord(row, k_block)); + b_1 = (compare_op_1(k_block, col)) ? + tensor_a.at(MatrixCoord(k_block, col)) : ElementA(); + } + + ComputeType compute_a_1(cast_if_scalar(a_1)); + ComputeType compute_b_1(cast_if_scalar(b_1)); + + accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]); + + // A^T x B or B x A^T (without diagonal) + if (SideModeA == SideMode::kLeft) { + a_2 = (compare_op_2(k_block, row)) ? + (tensor_a.at(MatrixCoord(k_block, row))) : ElementA(); + b_2 = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a_2 = tensor_b.at(MatrixCoord(row, k_block)); + b_2 = (compare_op_2(col, k_block)) ? 
+ tensor_a.at(MatrixCoord(col, k_block)) : ElementA(); + } + + ComputeType compute_a_2(cast_if_scalar(a_2)); + ComputeType compute_b_2(cast_if_scalar(b_2)); + + accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general Symm update (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_symm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum) { + compute_symm( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Symm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Symm { + + void 
operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_symm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_symm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..79e146f69b784a92ce61a093f410e93a66005cf8 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h @@ -0,0 +1,319 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued SYMM update in host-side code. 
+ + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + BlasMode BlasMode_ = BlasMode::kSymmetric, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_symm_complex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static SideMode const kSideModeA = SideModeA; + static FillMode const kFillModeA = FillModeA; + static BlasMode const kBlasMode = BlasMode_; + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + static_assert(kSideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert( + kFillModeA == FillMode::kLower || + kFillModeA == FillMode::kUpper, + "Fill Mode can either be Lower or Upper."); 
+ + using CompareOp_w_diag = typename TrMatrixCompareOp::Type; + using CompareOp_wo_diag = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp_w_diag compare_op_1; + CompareOp_wo_diag compare_op_2; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) + { + ElementA a_1 = ElementA(); + ElementB b_1 = ElementB(); + ElementA a_2 = ElementA(); + ElementB b_2 = ElementB(); + + // A x B or B x A (with diagonal) + if (kSideModeA == SideMode::kLeft) { + a_1 = (compare_op_1(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(); + b_1 = tensor_b.at(MatrixCoord(k_block, col)); + } else if (kSideModeA == SideMode::kRight) { + a_1 = tensor_b.at(MatrixCoord(row, k_block)); + b_1 = (compare_op_1(k_block, col)) ? 
+ tensor_a.at(MatrixCoord(k_block, col)) : ElementA(); + } + ComputeType compute_a_1 = ComputeType(a_1); + ComputeType compute_b_1 = ComputeType(b_1); + + // The imaginary parts of the diagonal elements of + // a complex data type are assumed and set to zero + if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kLeft && row == k_block) { + compute_a_1 = real(compute_a_1); + } else if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kRight && k_block == col) { + compute_b_1 = real(compute_b_1); + } + + accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]); + + // A^T x B or B x A^T (without diagonal) + if (kSideModeA == SideMode::kLeft) { + a_2 = (compare_op_2(k_block, row)) ? + (tensor_a.at(MatrixCoord(k_block, row))) : ElementA(); + b_2 = tensor_b.at(MatrixCoord(k_block, col)); + if (kBlasMode == BlasMode::kHermitian) + a_2 = conj(a_2); + } else if (kSideModeA == SideMode::kRight) { + a_2 = tensor_b.at(MatrixCoord(row, k_block)); + b_2 = (compare_op_2(col, k_block)) ? 
+ tensor_a.at(MatrixCoord(col, k_block)) : ElementA(); + if (kBlasMode == BlasMode::kHermitian) + b_2 = conj(b_2); + } + + ComputeType compute_a_2 = ComputeType(a_2); + ComputeType compute_b_2 = ComputeType(b_2); + + accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + ScalarType c = tensor_c.at(coord); + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * c); + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_b.add_pointer_offset(batch_stride_B); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + BlasMode BlasMode_ = cutlass::BlasMode::kSymmetric, + typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex +> +struct SymmComplex; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct SymmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_symm_complex>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); 
+ } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for gaussian multiply-add +template +struct SymmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_symm_complex>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h new file mode 100644 index 0000000000000000000000000000000000000000..d6b85ca1baf65ba811b7c8b3a224ca90bbce1680 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h @@ -0,0 +1,616 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. 
+*/
+
+#pragma once
+
+// Standard Library includes (std::max, std::abs / std::pow, uint64_t, std::pair)
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <utility>
+
+// Cutlass includes
+#include "cutlass/cutlass.h"
+#include "cutlass/relatively_equal.h"
+#include "cutlass/tensor_view.h"
+#include "cutlass/tensor_view_planar_complex.h"
+
+#include "cutlass/util/distribution.h"
+#include "tensor_foreach.h"
+
+namespace cutlass {
+namespace reference {
+namespace host {
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+
+/// Visitor that tracks the largest absolute element-wise difference between two views.
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+struct TensorGreatestErrorFunc {
+
+  //
+  // Data members
+  //
+
+  TensorView lhs;
+  TensorView rhs;
+  double result;
+
+  /// Ctor: the running maximum starts at zero
+  TensorGreatestErrorFunc(
+    TensorView const &lhs_,
+    TensorView const &rhs_
+  ) :
+    lhs(lhs_),
+    rhs(rhs_),
+    result(0.0) { }
+
+  /// Visits a coordinate and folds |lhs - rhs| into the running maximum
+  void operator()(Coord const &coord) {
+
+    Element lhs_ = lhs.at(coord);
+    Element rhs_ = rhs.at(coord);
+
+    result = std::max(result, std::abs(double(lhs_) - double(rhs_)));
+  }
+
+  /// Returns the greatest absolute difference seen so far
+  operator double() const {
+    return result;
+  }
+};
+
+/// Visitor that accumulates the relative error |lhs - rhs| / (|rhs| + epsilon) per element.
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+struct TensorMREFunc {
+
+  //
+  // Data members
+  //
+
+  TensorView lhs;
+  TensorView rhs;
+  double sum;
+  uint64_t count;
+  static constexpr double epsilon = 1e-6;   // keeps the denominator non-zero when rhs == 0
+
+  /// Ctor: starts with an empty sum and no visited elements
+  TensorMREFunc(
+    TensorView const &lhs_,
+    TensorView const &rhs_
+  ) :
+    lhs(lhs_),
+    rhs(rhs_),
+    sum(0.0),
+    count(0) { }
+
+  /// Visits a coordinate and adds its relative error to the running sum
+  void operator()(Coord const &coord) {
+
+    Element lhs_ = lhs.at(coord);
+    Element rhs_ = rhs.at(coord);
+
+    // The difference must be formed before dividing. The previous expression,
+    // lhs - rhs / (rhs + epsilon), divided only rhs and so measured roughly |lhs - 1|.
+    // The magnitude of rhs is used so a negative reference cannot cancel epsilon.
+    sum += std::abs(double(lhs_) - double(rhs_)) / (std::abs(double(rhs_)) + epsilon);
+    ++count;
+  }
+
+  /// Returns the mean relative error over the visited elements (0/0 if none were visited)
+  operator double() const {
+    return sum / double(count);
+  }
+};
+
+/// Visitor that accumulates the squared element-wise difference between two views.
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+struct TensorMSEFunc {
+
+  //
+  // Data members
+  //
+
+  TensorView lhs;
+  TensorView rhs;
+  double sum;
+  uint64_t count;
+
+  /// Ctor: starts with an empty sum and no visited elements
+  TensorMSEFunc(
+    TensorView const &lhs_,
+    TensorView const &rhs_
+  ) :
+    lhs(lhs_),
+    rhs(rhs_),
+    sum(0.0),
+    count(0) { }
+
+  /// Visits a coordinate and adds (lhs - rhs)^2 to the running sum
+  void operator()(Coord const &coord) {
+
+    Element lhs_ = lhs.at(coord);
+    Element rhs_ = rhs.at(coord);
+
+    sum += std::pow((double(lhs_) - double(rhs_)), 2);
+    ++count;
+  }
+
+  /// Returns the mean squared error over the visited elements (0/0 if none were visited)
+  operator double() const {
+    return sum / double(count);
+  }
+};
+
+/// Visitor that records whether every visited element pair compares equal.
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+struct TensorEqualsFunc {
+
+  //
+  // Data members
+  //
+
+  TensorView lhs;
+  TensorView rhs;
+  bool result;
+
+  /// Ctor
+  TensorEqualsFunc(): result(true) { }
+
+  /// Ctor
+  TensorEqualsFunc(
+    TensorView const &lhs_,
+    TensorView const &rhs_
+  ) :
+    lhs(lhs_), rhs(rhs_), result(true) { }
+
+  /// Visits a coordinate; a single mismatch latches the result to false
+  void operator()(Coord const &coord) {
+
+    Element lhs_ = lhs.at(coord);
+    Element rhs_ = rhs.at(coord);
+
+    if (lhs_ != rhs_) {
+      result = false;
+    }
+  }
+
+  /// Returns true if equal
+  operator bool() const {
+    return result;
+  }
+};
+
+/// Visitor that records whether every visited element pair passes relatively_equal().
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+struct TensorRelativelyEqualsFunc {
+
+  //
+  // Data members
+  //
+
+  TensorView lhs;
+  TensorView rhs;
+  Element epsilon;
+  Element nonzero_floor;
+  bool result;
+
+  /// Ctor
+  TensorRelativelyEqualsFunc(
+    TensorView const &lhs_,
+    TensorView const &rhs_,
+    Element epsilon_,
+    Element nonzero_floor_
+  ) :
+    lhs(lhs_),
+    rhs(rhs_),
+    epsilon(epsilon_),
+    nonzero_floor(nonzero_floor_),
+    result(true) { }
+
+  /// Visits a coordinate; a single failed comparison latches the result to false
+  void operator()(Coord const &coord) {
+
+    Element lhs_ = lhs.at(coord);
+    Element rhs_ = rhs.at(coord);
+
+    if (!relatively_equal(lhs_, rhs_, epsilon, nonzero_floor)) {
+      result = false;
+    }
+  }
+
+  /// Returns true if equal
+  operator bool() const {
+    return result;
+  }
+};
+
+} // namespace detail
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Returns the Mean Squared Error between two tensors, or -1 if their extents differ.
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+double TensorMSE(
+  TensorView const &lhs,
+  TensorView const &rhs) {
+
+  // Extents must be identical
+  if (lhs.extent() != rhs.extent()) {
+    return -1;
+  }
+
+  detail::TensorMSEFunc func(lhs, rhs);
+  TensorForEach(
+    lhs.extent(),
+    func
+  );
+
+  return double(func);
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Returns the Mean Relative Error between two tensors, or -1 if their extents differ.
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+double TensorMRE(
+  TensorView const &lhs,
+  TensorView const &rhs) {
+
+  // Extents must be identical
+  if (lhs.extent() != rhs.extent()) {
+    return -1;
+  }
+
+  detail::TensorMREFunc func(lhs, rhs);
+  TensorForEach(
+    lhs.extent(),
+    func
+  );
+
+  return double(func);
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Returns the greatest error between two tensors.
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+double TensorGreatestError(
+  TensorView const &lhs,
+  TensorView const &rhs) {
+
+  // Views of different extents cannot be compared element-wise; -1 reports that case
+  bool const extents_differ = (lhs.extent() != rhs.extent());
+  if (extents_differ) {
+    return -1;
+  }
+
+  // Walk every coordinate, keeping the largest absolute difference
+  detail::TensorGreatestErrorFunc max_error(lhs, rhs);
+  TensorForEach(lhs.extent(), max_error);
+
+  double const greatest = double(max_error);
+  return greatest;
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Returns true if two tensor views have the same extent and element-wise equal contents.
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+bool TensorEquals(
+  TensorView const &lhs,
+  TensorView const &rhs) {
+
+  // Different extents are never equal
+  bool const extents_differ = (lhs.extent() != rhs.extent());
+  if (extents_differ) {
+    return false;
+  }
+
+  // Walk every coordinate; the visitor latches to false on the first mismatch
+  detail::TensorEqualsFunc all_equal(lhs, rhs);
+  TensorForEach(lhs.extent(), all_equal);
+
+  bool const equal = bool(all_equal);
+  return equal;
+}
+
+/// Returns true if two planar complex tensor views are equal.
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+bool TensorEquals(
+  TensorViewPlanarComplex const &lhs,
+  TensorViewPlanarComplex const &rhs) {
+
+  // Different extents are never equal
+  bool const extents_differ = (lhs.extent() != rhs.extent());
+  if (extents_differ) {
+    return false;
+  }
+
+  // Real plane: views start at data()
+  detail::TensorEqualsFunc real_plane(
+    {lhs.data(), lhs.layout(), lhs.extent()},
+    {rhs.data(), rhs.layout(), rhs.extent()}
+  );
+  TensorForEach(lhs.extent(), real_plane);
+
+  // A mismatch in the real plane decides the answer without looking at the imaginary plane
+  bool const real_matches = bool(real_plane);
+  if (!real_matches) {
+    return false;
+  }
+
+  // Imaginary plane: views start imaginary_stride() elements past data()
+  detail::TensorEqualsFunc imag_plane(
+    {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()},
+    {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()}
+  );
+  TensorForEach(lhs.extent(), imag_plane);
+
+  return bool(imag_plane);
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Returns true if two tensor views have the same extent and every element pair passes
+/// relatively_equal() with the given epsilon and nonzero_floor.
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+bool TensorRelativelyEquals(
+  TensorView const &lhs,
+  TensorView const &rhs,
+  Element epsilon,
+  Element nonzero_floor) {
+
+  // Different extents are never relatively equal
+  bool const extents_differ = (lhs.extent() != rhs.extent());
+  if (extents_differ) {
+    return false;
+  }
+
+  // Walk every coordinate; the visitor latches to false on the first failed comparison
+  detail::TensorRelativelyEqualsFunc all_close(lhs, rhs, epsilon, nonzero_floor);
+  TensorForEach(lhs.extent(), all_close);
+
+  bool const close = bool(all_close);
+  return close;
+}
+
+/// Returns true if two planar complex tensor views are relatively equal.
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+bool TensorRelativelyEquals(
+  TensorViewPlanarComplex const &lhs,
+  TensorViewPlanarComplex const &rhs,
+  Element epsilon,
+  Element nonzero_floor) {
+
+  // Extents must be identical
+  if (lhs.extent() != rhs.extent()) {
+    return false;
+  }
+
+  // Real plane: views start at data()
+  detail::TensorRelativelyEqualsFunc real_func(
+    {lhs.data(), lhs.layout(), lhs.extent()},
+    {rhs.data(), rhs.layout(), rhs.extent()},
+    epsilon,
+    nonzero_floor
+  );
+
+  TensorForEach(
+    lhs.extent(),
+    real_func
+  );
+
+  if (!bool(real_func)) {
+    return false;
+  }
+
+  // Imaginary plane: views start imaginary_stride() elements past data(). It must use the
+  // same relative comparison as the real plane. The exact-equality functor used previously
+  // has no constructor taking (epsilon, nonzero_floor).
+  detail::TensorRelativelyEqualsFunc imag_func(
+    {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()},
+    {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()},
+    epsilon,
+    nonzero_floor
+  );
+
+  TensorForEach(
+    lhs.extent(),
+    imag_func
+  );
+
+  return bool(imag_func);
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Returns true if two tensor views are NOT equal: either their extents differ or at least
+/// one element pair compares unequal.
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+bool TensorNotEquals(
+  TensorView const &lhs,
+  TensorView const &rhs) {
+
+  // Views of different extents are never equal
+  if (lhs.extent() != rhs.extent()) {
+    return true;
+  }
+
+  detail::TensorEqualsFunc func(lhs, rhs);
+  TensorForEach(
+    lhs.extent(),
+    func
+  );
+
+  return !bool(func);
+}
+
+/// Returns true if two planar complex tensor views are NOT equal.
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+bool TensorNotEquals(
+  TensorViewPlanarComplex const &lhs,
+  TensorViewPlanarComplex const &rhs) {
+
+  // Plain negation of the planar complex equality test
+  bool const equal = TensorEquals(lhs, rhs);
+  return !equal;
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+
+/// Visitor that searches a view for a value and remembers where it was first seen.
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+struct TensorContainsFunc {
+
+  //
+  // Data members
+  //
+
+  TensorView view;
+  Element value;
+  bool contains;
+  Coord location;
+
+  //
+  // Methods
+  //
+
+  /// Ctor: nothing found yet
+  TensorContainsFunc(): contains(false) { }
+
+  /// Ctor: binds the view to search and the value to look for
+  TensorContainsFunc(
+    TensorView const &view_,
+    Element value_
+  ) :
+    view(view_), value(value_), contains(false) { }
+
+  /// Visits a coordinate; only the first matching coordinate is stored in location
+  void operator()(Coord const &coord) {
+
+    bool const matches = (view.at(coord) == value);
+    if (!matches) {
+      return;
+    }
+
+    if (!contains) {
+      location = coord;
+    }
+    contains = true;
+  }
+
+  /// Returns true if the value was found at any visited coordinate
+  operator bool() const {
+    return contains;
+  }
+};
+
+} // namespace detail
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Returns true if a value is present in a tensor
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+bool TensorContains(
+  TensorView const & view,
+  Element value) {
+
+  // Walk every coordinate; the visitor latches to true on the first match
+  detail::TensorContainsFunc finder(view, value);
+  TensorForEach(view.extent(), finder);
+
+  bool const found = bool(finder);
+  return found;
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Returns a pair containing a boolean of whether a value exists in a tensor and the location
+/// of the first occurrence. If the value is not contained in the tensor, the second element of the
+/// pair is undefined.
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +std::pair > TensorFind( + TensorView const & view, + Element value) { + + detail::TensorContainsFunc func( + view, + value + ); + + TensorForEach( + view.extent(), + func + ); + + return std::make_pair(bool(func), func.location); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp new file mode 100644 index 0000000000000000000000000000000000000000..27ef969b4ff2b6d8f3a53f3d1a3e5ec3e5203ec3 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp @@ -0,0 +1,101 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include + +// Cute includes +#include "cute/tensor.hpp" + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if two tensor views are equal. 
+template < + typename TensorL, + typename TensorR +> +bool TensorEquals( + TensorL lhs, + TensorR rhs) { + + // Extents must be identical + if (cute::size(lhs) != cute::size(rhs)) { + return false; + } + + for (int64_t idx = 0; idx < cute::size(lhs); ++idx) { + if (lhs(idx) != rhs(idx)) { + return false; + } + } + + return true; +} + +/// Returns true if two tensor views are NOT equal. +template < + typename TensorL, + typename TensorR +> +bool TensorNotEquals( + TensorL lhs, + TensorR rhs) { + + return TensorEquals(lhs, rhs); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h new file mode 100644 index 0000000000000000000000000000000000000000..d2a43b1295c8ab18c7d649c79b0364b6d3e7c48c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h @@ -0,0 +1,256 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. 
+*/
+
+#pragma once
+
+// Standard Library includes
+#include <utility>
+
+// Cutlass includes
+#include "cutlass/cutlass.h"
+#include "tensor_foreach.h"
+
+namespace cutlass {
+namespace reference {
+namespace host {
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+
+/// Helper to convert between types using DstElement's explicit conversion from SrcElement
+template <
+  typename DstElement,
+  typename SrcElement
+>
+struct TrivialConvert {
+
+  TrivialConvert() { }
+
+  /// Returns src converted to DstElement
+  DstElement operator()(SrcElement src) const {
+    return DstElement(src);
+  }
+};
+
+/// Helper to conditionally copy between tensor views. An element is copied only when its
+/// coordinate is contained in both the destination and the source view.
+template <
+  typename DstElement,
+  typename DstLayout,
+  typename SrcElement,
+  typename SrcLayout,
+  typename F
+>
+struct TensorCopyIf {
+
+  using DstTensorView = TensorView<DstElement, DstLayout>;
+  using SrcTensorView = TensorView<SrcElement, SrcLayout>;
+
+  //
+  // Data members
+  //
+
+  DstTensorView dst;      ///< view written to
+  SrcTensorView src;      ///< view read from
+  F convert;              ///< functor applied to each source element before it is stored
+
+  //
+  // Methods
+  //
+
+  TensorCopyIf() { }
+
+  TensorCopyIf(
+    DstTensorView const &dst_,
+    SrcTensorView const &src_,
+    F const &convert_): dst(dst_), src(src_), convert(convert_) {}
+
+  /// Copies based on destination and source bounds
+  void operator()(Coord<DstLayout::kRank> const &coord) {
+    // Coordinates outside either view are skipped silently
+    if (dst.contains(coord) && src.contains(coord)) {
+      dst.at(coord) = convert(src.at(coord));
+    }
+  }
+};
+
+} // namespace detail
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Copies elements from one tensor view into another, satisfying bounds of each tensor.
+template <
+  typename DstElement,            /// Destination tensor's element type
+  typename DstLayout,             /// Destination tensor's layout
+  typename SrcElement,            /// Source tensor's element type
+  typename SrcLayout,             /// Source tensor's layout
+  typename F                      /// Transformation functor
+>
+void TensorCopy(
+  TensorView<DstElement, DstLayout> dst,
+  TensorView<SrcElement, SrcLayout> src,
+  F const &transform) {
+
+  using CopyIf = detail::TensorCopyIf<
+    DstElement,
+    DstLayout,
+    SrcElement,
+    SrcLayout,
+    F>;
+
+  CopyIf copy_if(dst, src, transform);
+
+  // Iterate over the destination extent; the functor skips coordinates outside either view
+  TensorForEach(dst.extent(), copy_if);
+}
+
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent
+/// to avoid out of bounds accesses.
+template <
+  typename DstElement,            /// Destination tensor's element type
+  typename DstLayout,             /// Destination tensor's layout
+  typename SrcElement,            /// Source tensor's element type
+  typename SrcLayout,             /// Source tensor's layout
+  typename F                      /// Transformation functor
+>
+void TensorCopy(
+  TensorView<DstElement, DstLayout> dst,
+  TensorRef<SrcElement, SrcLayout> src,
+  F const &transform) {
+
+  using CopyIf = detail::TensorCopyIf<
+    DstElement,
+    DstLayout,
+    SrcElement,
+    SrcLayout,
+    F>;
+
+  // The source reference carries no extent, so it is given the destination's extent
+  TensorView<SrcElement, SrcLayout> src_view(src, dst.extent());
+
+  CopyIf copy_if(dst, src_view, transform);
+
+  TensorForEach(dst.extent(), copy_if);
+}
+
+/// Copies elements from a TensorView into a TensorRef. Assumes destination tensor has sufficient
+/// extent to avoid out of bounds accesses.
+template <
+  typename DstElement,            /// Destination tensor's element type
+  typename DstLayout,             /// Destination tensor's layout
+  typename SrcElement,            /// Source tensor's element type
+  typename SrcLayout,             /// Source tensor's layout
+  typename F                      /// Transformation functor
+>
+void TensorCopy(
+  TensorRef<DstElement, DstLayout> dst,
+  TensorView<SrcElement, SrcLayout> src,
+  F const &transform) {
+
+  using CopyIf = detail::TensorCopyIf<
+    DstElement,
+    DstLayout,
+    SrcElement,
+    SrcLayout,
+    F>;
+
+  // The destination reference carries no extent, so it is given the source's extent
+  TensorView<DstElement, DstLayout> dst_view(dst, src.extent());
+
+  CopyIf copy_if(dst_view, src, transform);
+
+  TensorForEach(src.extent(), copy_if);
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds
+/// if SrcElement can be converted to DstElement.
+template <
+  typename DstElement,            /// Destination tensor's element type
+  typename DstLayout,             /// Destination tensor's layout
+  typename SrcElement,            /// Source tensor's element type
+  typename SrcLayout              /// Source tensor's layout
+>
+void TensorCopy(
+  TensorView<DstElement, DstLayout> dst,
+  TensorView<SrcElement, SrcLayout> src) {
+
+  // Default conversion: construct DstElement from SrcElement
+  detail::TrivialConvert<DstElement, SrcElement> convert;
+
+  TensorCopy(dst, src, convert);
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Copies elements from a TensorRef into a TensorView using the default element conversion.
+/// Succeeds if SrcElement can be converted to DstElement. Assumes source tensor has sufficient
+/// extent to avoid out of bounds accesses.
+template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout, /// Source tensor's layout + typename F /// Transformation functor +> +void TensorCopy( + TensorView dst, + TensorRef src) { + + detail::TrivialConvert convert; + + TensorCopy(dst, src, convert); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds +/// if SrcElement can be converted to DstElement. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout /// Source tensor's layout +> +void TensorCopy( + TensorRef dst, + TensorView src) { + + detail::TrivialConvert convert; + + TensorCopy(dst, src, convert); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h new file mode 100644 index 0000000000000000000000000000000000000000..5470df29358799f6d5e6628e8722f0e3dc05485f --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h @@ -0,0 +1,341 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. 
+*/ + +#pragma once + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" + +#include "tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to apply a binary operator in place +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementD, + typename LayoutD, + typename BinaryFunc> +struct TensorFuncBinaryOp { + + // + // Data members + // + + /// View of left-hand-side tensor + TensorView view_d; + TensorRef view_a; + TensorRef view_b; + BinaryFunc func; + + // + // Methods + // + + /// Constructor + TensorFuncBinaryOp() { } + + /// Constructor + TensorFuncBinaryOp( + TensorView const & view_d_, + TensorRef const & view_a_, + TensorRef const & view_b_, + BinaryFunc func = BinaryFunc() + ): + view_d(view_d_), view_a(view_a_), view_b(view_b_), func(func) { } + + /// Equality check + void operator()(Coord const &coord) const { + view_d.at(coord) = func( + ElementD(view_a.at(coord)), + ElementD(view_b.at(coord)) + ); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Adds two tensors and stores in the destination tensor: d = a + b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorAdd( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + 
detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::plus + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Adds a tensor in place: d = d .+ a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorAdd( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorAdd(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Subtracts two tensors and stores in the destination tensor: d = a - b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorSub( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference + ) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::minus + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Subtracts two tensors in place: d = d .- a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorSub( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference + ) { + + TensorSub(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Multiplies two tensors and stores in the destination tensor: d = a .* b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorMul( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::multiplies + > 
func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Multiplies tensors in place: d = d .* a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorMul( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorMul(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Divides two tensors and stores in the destination tensor: d = a ./ b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorDiv( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::divides + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Divides tensors in place: d = d ./ a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorDiv( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorDiv(d, d, a); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Divides two tensors and stores in the destination tensor: d = a ./ b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorModulus( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::divides + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Divides tensors in place: d = d ./ a +template < + typename ElementD, + typename LayoutD, + 
typename ElementA, + typename LayoutA +> +void TensorModulus( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorDiv(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h new file mode 100644 index 0000000000000000000000000000000000000000..645902f7dd7b62bc98a479e4956dfb4b437d46a7 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h @@ -0,0 +1,1718 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/* \file
+  \brief Provides several functions for filling tensors with data.
+*/
+
+#pragma once
+
+// Standard Library includes
+#include <utility>
+#include <cstdlib>
+#include <cmath>
+#include <random>
+#include <stdexcept>
+
+// Cutlass includes
+#include "cutlass/cutlass.h"
+#include "cutlass/complex.h"
+#include "cutlass/quaternion.h"
+#include "cutlass/array.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/subbyte_reference.h"
+#include "cutlass/tensor_view.h"
+#include "cutlass/tensor_view_planar_complex.h"
+#include "cutlass/blas3.h"
+
+#include "cutlass/util/distribution.h"
+#include "tensor_foreach.h"
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace reference {
+namespace host {
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+
+/// Functor that stores one fixed value at every visited coordinate of a tensor view
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+struct TensorFillFunc {
+
+  using TensorView = TensorView<Element, Layout>;
+
+  //
+  // Data members
+  //
+
+  TensorView view;      ///< destination view
+  Element value;        ///< value written to each element
+
+  //
+  // Methods
+  //
+
+  TensorFillFunc(
+    TensorView const &view_ = TensorView(),
+    Element value_ = Element(0)
+  ): view(view_), value(value_) { }
+
+  /// Writes the fill value at the given coordinate
+  void operator()(Coord<Layout::kRank> const & coord) const {
+    view.at(coord) = value;
+  }
+};
+
+/// Returns a pair of values of the Gaussian distribution generated by the Box Muller method.
+/// Consumes two values from std::rand() per call.
+struct BoxMullerFunc {
+
+  BoxMullerFunc() {}
+
+  void operator()(
+    double* rnd,                        ///< Size-2 vector to be filled with random values
+    double mean = 0,                    ///< Mean of the Gaussian distribution
+    double stddev = 1,                  ///< Standard deviation of the Gaussian distribution
+    double pi = std::acos(-1)) const {
+
+    double u1 = double(std::rand()) / double(RAND_MAX);
+    double u2 = double(std::rand()) / double(RAND_MAX);
+
+    // std::rand() may return 0, and std::log(0) is -infinity, which would yield an infinite or
+    // NaN sample. Substitute half of the generator's smallest step to keep the sample finite.
+    if (u1 <= 0) {
+      u1 = 0.5 / double(RAND_MAX);
+    }
+
+    // Both outputs share the same radius and differ only in the angular term
+    double const radius = std::sqrt(-2 * std::log(u1));
+    rnd[0] = radius * std::cos(2 * pi * u2);
+    rnd[1] = radius * std::sin(2 * pi * u2);
+
+    // Shift and scale the standard normal samples
+    rnd[0] = mean + stddev * rnd[0];
+    rnd[1] = mean + stddev * rnd[1];
+  }
+};
+} // namespace detail
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Fills a tensor with a uniform value
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+void TensorFill(
+  TensorView<Element, Layout> dst,        ///< destination tensor
+  Element val = Element(0)) {             ///< value to uniformly fill it with
+
+  detail::TensorFillFunc<Element, Layout> func(dst, val);
+
+  TensorForEach(
+    dst.extent(),
+    func
+  );
+}
+
+/// Fills a planar complex tensor with a uniform value: the real plane receives val.real() and
+/// the imaginary plane receives val.imag().
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+void TensorFill(
+  TensorViewPlanarComplex<Element, Layout> dst,                     ///< destination tensor
+  cutlass::complex<Element> val = cutlass::complex<Element>(0)) {   ///< value to uniformly fill it with
+
+  TensorFill(dst.view_real(), val.real());
+  TensorFill(dst.view_imag(), val.imag());
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+
+/// Generates Gaussian-distributed values, optionally rounded to a fixed number of fractional
+/// bits and optionally sparsified. The Gaussian samples come from std::rand(), which the
+/// constructor seeds with std::srand().
+template <typename Element>
+struct RandomGaussianFunc {
+
+  uint64_t seed;        ///< seed passed to std::srand() (truncated to unsigned)
+  double mean;          ///< mean of the Gaussian distribution
+  double stddev;        ///< standard deviation of the Gaussian distribution
+  int int_scale;        ///< if non-negative, samples are rounded to multiples of 2^-int_scale
+  double pi;            ///< cached value of pi
+  double pnz;           ///< probability that a generated element is kept nonzero
+  bool exclude_zero;    ///< if true, results equal to zero are replaced by a nonzero value
+
+  //
+  // Methods
+  //
+  RandomGaussianFunc(
+    uint64_t seed_ = 0,
+    double mean_ = 0,
+    double stddev_ = 1,
+    int int_scale_ = -1,
+    double pnz_ = 1.0,
+    bool exclude_zero_ = false
+  ):
+    seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) {
+      std::srand((unsigned)seed);
+  }
+
+  /// Compute random value and update RNG state
+  Element operator()() const {
+
+    // Box-Muller transform to generate random numbers with Normal distribution
+    double u1 = double(std::rand()) / double(RAND_MAX);
+    double u2 = double(std::rand()) / double(RAND_MAX);
+
+    // std::rand() may return 0, and std::log(0) is -infinity, which would yield an infinite or
+    // NaN sample. Substitute half of the generator's smallest step to keep the sample finite.
+    if (u1 <= 0) {
+      u1 = 0.5 / double(RAND_MAX);
+    }
+
+    // Compute Gaussian random value
+    double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2);
+    rnd = mean + stddev * rnd;
+
+    // Scale and convert final result
+    Element result;
+
+    // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian.
+    // With pnz >= 1 the outcome is always "nonzero", so the (comparatively expensive)
+    // std::random_device / std::mt19937 construction is skipped in that case.
+    bool bernoulli_result = true;
+    if (pnz < 1.0) {
+      std::random_device rnd_device;
+      std::mt19937 bernoulli_rnd(rnd_device());
+      std::bernoulli_distribution bernoulli_dist(pnz);
+      bernoulli_result = bernoulli_dist(bernoulli_rnd);
+    }
+
+    // Sample from the Gaussian distribution for a nonzero element
+    if (bernoulli_result) {
+      if (int_scale >= 0) {
+        // Round to the nearest multiple of 2^-int_scale
+        double const scale = double(1 << int_scale);
+        rnd = double(std::llround(rnd * scale)) / scale;
+        result = static_cast<Element>(rnd);
+      }
+      else {
+        result = static_cast<Element>(rnd);
+      }
+    }
+    else {
+      result = static_cast<Element>(0);
+    }
+
+    // Note that exclude_zero = true will disable the bernoulli_result above by unsetting zeros
+    if (exclude_zero && result == Element(0)) {
+      // Push the sample one unit away from zero, keeping its sign
+      if (rnd > 0) {
+        rnd += 1;
+      } else {
+        rnd -= 1;
+      }
+      result = Element(rnd);
+    }
+
+    return result;
+  }
+};
+
+/// Partial specialization for initializing a complex value.
+template <typename Element>
+struct RandomGaussianFunc<complex<Element> > {
+
+  uint64_t seed;        ///< seed passed to std::srand() (truncated to unsigned)
+  double mean;          ///< mean of the Gaussian distribution
+  double stddev;        ///< standard deviation of the Gaussian distribution
+  int int_scale;        ///< if non-negative, samples are rounded to multiples of 2^-int_scale
+  double pi;            ///< cached value of pi
+  double pnz;           ///< probability that a generated element is kept nonzero
+  bool exclude_zero;    ///< if true, an all-zero result is replaced by a nonzero value
+
+  //
+  // Methods
+  //
+  RandomGaussianFunc(
+    uint64_t seed_ = 0,
+    double mean_ = 0,
+    double stddev_ = 1,
+    int int_scale_ = -1,
+    double pnz_ = 1.0,
+    bool exclude_zero_ = false
+  ):
+    seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) {
+      std::srand((unsigned)seed);
+  }
+
+  /// Compute random value and update RNG state
+  complex<Element> operator()() const {
+
+    Element reals[2];
+
+    // One Box-Muller call yields the two samples used for the real and imaginary parts
+    double rnd[2];
+    detail::BoxMullerFunc func;
+    func(rnd, mean, stddev, pi);
+
+    // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian.
+    // With pnz >= 1 the outcome is always "nonzero", so the (comparatively expensive)
+    // std::random_device / std::mt19937 construction is skipped in that case.
+    bool bernoulli_result = true;
+    if (pnz < 1.0) {
+      std::random_device rnd_device;
+      std::mt19937 bernoulli_rnd(rnd_device());
+      std::bernoulli_distribution bernoulli_dist(pnz);
+      bernoulli_result = bernoulli_dist(bernoulli_rnd);
+    }
+
+    // Sample from the Gaussian distribution for a nonzero element
+    if (bernoulli_result) {
+      if (int_scale >= 0) {
+        // rnd[] keeps the scaled integer values; reals[] receives the rescaled results
+        double const scale = double(1 << int_scale);
+        rnd[0] = double(std::llround(rnd[0] * scale));
+        rnd[1] = double(std::llround(rnd[1] * scale));
+        reals[0] = from_real<Element>(rnd[0] / scale);
+        reals[1] = from_real<Element>(rnd[1] / scale);
+      }
+      else {
+        reals[0] = from_real<Element>(rnd[0]);
+        reals[1] = from_real<Element>(rnd[1]);
+      }
+    }
+    else {
+      reals[0] = from_real<Element>(0);
+      reals[1] = from_real<Element>(0);
+    }
+
+    // Note that this will invalidate the above else statement because it unsets zero elements
+    if (exclude_zero &&
+        reals[0] == from_real<Element>(0.0) &&
+        reals[1] == from_real<Element>(0.0)) {
+
+      // Only the real part is pushed away from zero
+      if (rnd[0] > 0.0) {
+        rnd[0] += 1.0;
+      } else {
+        rnd[0] -= 1.0;
+      }
+      reals[0] = from_real<Element>(rnd[0]);
+    }
+
+    return complex<Element>(reals[0], reals[1]);
+  }
+};
+
+/// Partial specialization for initializing a Quaternion value.
+template <typename Element>
+struct RandomGaussianFunc<Quaternion<Element> > {
+
+  uint64_t seed;        ///< seed passed to std::srand() (truncated to unsigned)
+  double mean;          ///< mean of the Gaussian distribution
+  double stddev;        ///< standard deviation of the Gaussian distribution
+  int int_scale;        ///< if non-negative, samples are rounded to multiples of 2^-int_scale
+  double pi;            ///< cached value of pi
+  double pnz;           ///< probability that a generated element is kept nonzero
+  bool exclude_zero;    ///< if true, an all-zero result is replaced by a nonzero value
+
+  //
+  // Methods
+  //
+  RandomGaussianFunc(
+    uint64_t seed_ = 0,
+    double mean_ = 0,
+    double stddev_ = 1,
+    int int_scale_ = -1,
+    double pnz_ = 1.0,
+    bool exclude_zero_ = false
+  ):
+    seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) {
+      std::srand((unsigned)seed);
+  }
+
+  /// Compute random value and update RNG state
+  Quaternion<Element> operator()() const {
+
+    Element reals[4];
+
+    // Two Box-Muller calls yield the four samples used for the quaternion components
+    double rnd1[2];
+    double rnd2[2];
+    detail::BoxMullerFunc func;
+    func(rnd1, mean, stddev, pi);
+    func(rnd2, mean, stddev, pi);
+
+    // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian.
+    // With pnz >= 1 the outcome is always "nonzero", so the (comparatively expensive)
+    // std::random_device / std::mt19937 construction is skipped in that case.
+    bool bernoulli_result = true;
+    if (pnz < 1.0) {
+      std::random_device rnd_device;
+      std::mt19937 bernoulli_rnd(rnd_device());
+      std::bernoulli_distribution bernoulli_dist(pnz);
+      bernoulli_result = bernoulli_dist(bernoulli_rnd);
+    }
+
+    // Sample from the Gaussian distribution for a nonzero element
+    if (bernoulli_result) {
+      if (int_scale >= 0) {
+        // rnd1[] / rnd2[] keep the scaled integer values; reals[] receives the rescaled results
+        double const scale = double(1 << int_scale);
+        rnd1[0] = double(std::llround(rnd1[0] * scale));
+        rnd1[1] = double(std::llround(rnd1[1] * scale));
+        rnd2[0] = double(std::llround(rnd2[0] * scale));
+        rnd2[1] = double(std::llround(rnd2[1] * scale));
+
+        reals[0] = from_real<Element>(rnd1[0] / scale);
+        reals[1] = from_real<Element>(rnd1[1] / scale);
+        reals[2] = from_real<Element>(rnd2[0] / scale);
+        reals[3] = from_real<Element>(rnd2[1] / scale);
+      }
+      else {
+        reals[0] = from_real<Element>(rnd1[0]);
+        reals[1] = from_real<Element>(rnd1[1]);
+        reals[2] = from_real<Element>(rnd2[0]);
+        reals[3] = from_real<Element>(rnd2[1]);
+      }
+    }
+    else {
+      reals[0] = from_real<Element>(0);
+      reals[1] = from_real<Element>(0);
+      reals[2] = from_real<Element>(0);
+      reals[3] = from_real<Element>(0);
+    }
+
+    // Note that this will invalidate the above else statement because it unsets zero elements
+    if (exclude_zero &&
+        reals[0] == from_real<Element>(0) &&
+        reals[1] == from_real<Element>(0) &&
+        reals[2] == from_real<Element>(0) &&
+        reals[3] == from_real<Element>(0)) {
+
+      // Only the first component is pushed away from zero
+      if (rnd1[0] > 0.0) {
+        rnd1[0] += 1.0;
+      } else {
+        rnd1[0] -= 1.0;
+      }
+      reals[0] = from_real<Element>(rnd1[0]);
+    }
+
+    return Quaternion<Element>(reals[0], reals[1], reals[2], reals[3]);
+  }
+};
+
+/// Computes a random Gaussian distribution
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+struct TensorFillGaussianFunc {
+
+  using TensorView = TensorView<Element, Layout>;
+
+  //
+  // Data members
+  //
+
+  TensorView view;                      ///< destination view
+  RandomGaussianFunc<Element> func;     ///< generator invoked once per visited coordinate
+
+  //
+  // Methods
+  //
+
+  /// Construction of Gaussian RNG functor.
+  TensorFillGaussianFunc(
+    TensorView view_ = TensorView(),
+    RandomGaussianFunc<Element> func_ = RandomGaussianFunc<Element>()
+  ):
+    view(view_), func(func_) {
+
+  }
+
+  /// Compute random value and update RNG state
+  void operator()(Coord<Layout::kRank> const &coord) const {
+    view.at(coord) = func();
+  }
+};
+
+/// Computes a random Gaussian distribution for a rank-2 tensor, writing only the triangle
+/// (including the diagonal) selected by the fill mode.
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+struct TensorFillSymmetricGaussianFunc {
+
+  using TensorView = TensorView<Element, Layout>;
+
+  //
+  // Data members
+  //
+
+  TensorView view;                      ///< destination view
+  RandomGaussianFunc<Element> func;     ///< generator invoked once per written coordinate
+  cutlass::FillMode fill_mode;          ///< selects the lower or upper triangle
+
+  //
+  // Methods
+  //
+
+  /// Construction of Gaussian RNG functor.
+  TensorFillSymmetricGaussianFunc(
+    TensorView view_ = TensorView(),
+    RandomGaussianFunc<Element> func_ = RandomGaussianFunc<Element>(),
+    cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid
+  ):
+    view(view_), func(func_), fill_mode(fill_mode_) {
+
+  }
+
+  /// Compute random value and update RNG state
+  void operator()(Coord<Layout::kRank> const &coord) const {
+    // Fill half of matrix based on FillMode. Nothing is written for tensors whose rank is not
+    // 2 or for fill modes other than kLower / kUpper.
+    if (Layout::kRank == 2 &&
+        fill_mode == cutlass::FillMode::kLower &&
+        coord[0] >= coord[1]) {
+      view.at(coord) = func();
+    } else if (Layout::kRank == 2 &&
+        fill_mode == cutlass::FillMode::kUpper &&
+        coord[0] <= coord[1]) {
+      view.at(coord) = func();
+    }
+  }
+};
+
+} // namespace detail
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Fills a tensor with random values with a Gaussian distribution.
+template <
+  typename Element,               ///< Element type
+  typename Layout>                ///< Layout function
+void TensorFillRandomGaussian(
+  TensorView<Element, Layout> dst,        ///< destination tensor
+  uint64_t seed,                          ///< seed for RNG
+  double mean = 0,                        ///< Gaussian distribution's mean
+  double stddev = 1,                      ///< Gaussian distribution's standard deviation
+  int bits = -1,                          ///< If non-negative, specifies number of fractional bits that
+                                          ///  are not truncated to zero. Permits reducing precision of
+                                          ///  data.
+  double pnz = 1.0,                       ///< Probability that an element is nonzero; elements not
+                                          ///  selected are set to zero.
+  bool exclude_zero = false) {            ///< Exclude zeros from tensor init.
+
+  detail::RandomGaussianFunc<Element> random_func(seed, mean, stddev, bits, pnz, exclude_zero);
+
+  detail::TensorFillGaussianFunc<Element, Layout> func(
+    dst,
+    random_func
+  );
+
+  TensorForEach(
+    dst.extent(),
+    func
+  );
+}
+
+/// Fills a tensor with random values with a Gaussian distribution.
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomGaussian( + TensorViewPlanarComplex dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0, /// are not truncated to zero. Permits reducing precision of + /// data. + bool exclude_zero = false) { ///< Exclude zeros from tensor init. + + TensorFillRandomGaussian(dst.view_real(), seed, mean, stddev, bits, pnz); + TensorFillRandomGaussian(dst.view_imag(), ~seed, mean, stddev, bits, pnz); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a Gaussian distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillSymmetricRandomGaussian( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0) { /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz); + + detail::TensorFillSymmetricGaussianFunc func( + dst, + random_func, + fill_mode + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values of a Gaussian distribution. 
+template < + typename Element ///< Element type +> +void BlockFillRandomGaussian( + Element *ptr, ///< destination buffer + size_t capacity, ///< number of elements + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0) { /// are not truncated to zero. Permits reducing precision of + /// data. + + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz); + + for (size_t i = 0; i < capacity; ++i) { + ReferenceFactory::get(ptr, i) = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomUniformFunc { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + double pnan; +private: + using engine_type = std::mt19937; +public: + engine_type bernoulli_rnd; + std::bernoulli_distribution bernoulli_dist; + + bool exclude_zero; + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1, + double pnan_ = 0, + bool exclude_zero_ = false + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_) + , bernoulli_rnd{static_cast(seed_)} + , bernoulli_dist(pnan_) + , exclude_zero(exclude_zero_) + { + std::srand((unsigned)seed); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero) { + min = (min == 0.0) ? min + 1: min; + range = (max == 0.0) ? range - 1: range; + } + } + + + /// Compute random value and update RNG state + Element operator()() { + + // Sample from NaN distribution. 
+ if constexpr (std::numeric_limits::has_quiet_NaN) { + if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { + return Element(NAN); + } + } + + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + if (int_scale >= 0) { + rnd = double(std::llround(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(Real(rnd)); + } + else { + result = static_cast(Real(rnd)); + } + + if (exclude_zero && result == Element(0)) { + if (rnd > 0.0) { + rnd = std::min(min + range, rnd + 1.0); + } else { + rnd = std::max(min, rnd - 1.0); + } + result = static_cast(Real(rnd)); + } + + return result; + } +}; + +/// Partial specialization for initializing a complex value. +template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + double pnan; +private: + using engine_type = std::mt19937; +public: + engine_type bernoulli_rnd; + std::bernoulli_distribution bernoulli_dist; + + bool exclude_zero; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1, + double pnan_ = 0, + bool exclude_zero_ = false + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_) + , bernoulli_rnd{static_cast(seed_)} + , bernoulli_dist(pnan_) + , exclude_zero(exclude_zero_) { + std::srand((unsigned)seed); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero) { + min = (min == 0.0) ? min + 1: min; + range = (max == 0.0) ? range - 1: range; + } + } + + + /// Compute random value and update RNG state + complex operator()() { + + // Sample from NaN distribution. 
+ if constexpr (std::numeric_limits::has_quiet_NaN) { + if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { + return Element(NAN); + } + } + + Element reals[2]; + + for (int i = 0; i < 2; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(std::llround(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + + if (exclude_zero && + i == 0 && + reals[0] == from_real(0.0)) { + + if (rnd > 0.0) { + rnd = std::min(min + range, rnd + 1.0); + } else { + rnd = std::max(min, rnd - 1.0); + } + reals[0] = from_real(Real(rnd)); + } + + } + + return complex(reals[0], reals[1]); + } +}; + +/// Partial specialization for initializing a Quaternion value. +template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + double pnan; +private: + using engine_type = std::mt19937; +public: + engine_type bernoulli_rnd; + std::bernoulli_distribution bernoulli_dist; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1, + double pnan_ = 0 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_), + bernoulli_rnd{static_cast(seed_)}, + bernoulli_dist(pnan_) + { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + Quaternion operator()() { + + // Sample from NaN distribution. 
+ if constexpr (std::numeric_limits::has_quiet_NaN) { + if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { + return Element(NAN); + } + } + + Element reals[4]; + + for (int i = 0; i < 4; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(std::llround(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + } + + return make_Quaternion(reals[0], reals[1], reals[2], reals[3]); + } +}; + +/// Computes a random uniform distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomUniformFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomUniformFunc func; + + // + // Methods + // + + /// Construction of uniform RNG functor. + TensorFillRandomUniformFunc( + TensorView view_ = TensorView(), + RandomUniformFunc func_ = RandomUniformFunc() + ): + view(view_), func(func_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) { + + view.at(coord) = func(); + } +}; + +/// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a uniform distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillSymmetricRandomUniformFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomUniformFunc func; + cutlass::FillMode fill_mode; + + // + // Methods + // + + /// Construction of uniform RNG functor. 
+ TensorFillSymmetricRandomUniformFunc( + TensorView view_ = TensorView(), + RandomUniformFunc func_ = RandomUniformFunc(), + cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid + ): + view(view_), func(func_), fill_mode(fill_mode_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) { + // Fill half of matrix based on FillMode + if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kLower && + coord[0] >= coord[1]) { + view.at(coord) = func(); + } else if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kUpper && + coord[0] <= coord[1]) { + view.at(coord) = func(); + } + } +}; + +/// Computes a random Uniform distribution and pads diagonal with zeros +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillPadDiagonalRandomUniformFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomUniformFunc func; + cutlass::FillMode fill_mode; + int alignment; + + // + // Methods + // + + /// Construction of uniform RNG functor. 
+ TensorFillPadDiagonalRandomUniformFunc( + TensorView view_ = TensorView(), + RandomUniformFunc func_ = RandomUniformFunc(), + cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid, + int alignment_ = 1 + ): + view(view_), func(func_), fill_mode(fill_mode_), alignment(alignment_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) { + // Fill half of matrix based on FillMode + if (Layout::kRank == 2 && + (fill_mode == cutlass::FillMode::kLower) && + (coord[0] >= coord[1]) || + ((coord[1] - coord[0]) >= alignment)) { + view.at(coord) = func(); + } else if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kUpper && + (coord[0] <= coord[1]) || + ((coord[0] - coord[1]) >= alignment)) { + view.at(coord) = func(); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values of a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. + bool exclude_zero = false) { ///< Exclude zero from tensor init + detail::RandomUniformFunc random_func(seed, max, min, bits, pnan, exclude_zero); + + detail::TensorFillRandomUniformFunc func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/// Fills a tensor with random values of a uniform random distribution. 
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorViewPlanarComplex dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. + bool exclude_zero = false) { ///< Exclude zero from tensor init + + TensorFillRandomUniform(dst.view_real(), seed, max, min, bits, pnan, exclude_zero); + TensorFillRandomUniform(dst.view_imag(), ~seed, max, min, bits, pnan, exclude_zero); +} + + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorView, Layout> dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + detail::RandomUniformFunc> random_func(seed, max, min, bits); + + detail::TensorFillRandomUniformFunc, Layout> func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. 
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillSymmetricRandomUniform( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomUniformFunc random_func(seed, max, min, bits); + + detail::TensorFillSymmetricRandomUniformFunc func( + dst, + random_func, + fill_mode + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/// Fills a tensor with random values with a uniform random distribution pads zeros along diagonal +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillPadDiagonalRandomUniform( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. 
+ int alignment = 1 +) { + + detail::RandomUniformFunc random_func(seed, max, min, bits); + + detail::TensorFillPadDiagonalRandomUniformFunc func( + dst, + random_func, + fill_mode, + alignment + ); + + TensorForEach( + dst.extent(), + func + ); +} +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a uniform value +template < + typename Element ///< Element type +> +void BlockFill( + Element *ptr, + size_t capacity, + Element val + ) { + for (size_t i = 0; i < capacity; ++i) { + ReferenceFactory::get(ptr, i) = val; + } +} + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomUniform( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0) { ///< Percentage of NaN elements. 
+ detail::RandomUniformFunc random_func(seed, max, min, bits, pnan); + + for (size_t i = 0; i < capacity; ++i) { + ReferenceFactory::get(ptr, i) = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillDiagonalFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Element diag; + Element other; + + // + // Methods + // + + TensorFillDiagonalFunc( + TensorView const &view_ = TensorView(), + Element diag_ = Element(1), + Element other_ = Element(0) + ): + view(view_), diag(diag_), other(other_) { } + + void operator()(Coord const & coord) const { + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + view.at(coord) = (is_diag ? diag : other); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor everywhere with a unique value for its diagonal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillDiagonal( + TensorView dst, ///< destination tensor + Element diag = Element(1), ///< value to write in the diagonal + Element other = Element(0)) { ///< value to write off the diagonal + + detail::TensorFillDiagonalFunc func( + dst, + diag, + other + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to fill a tensor's diagonal with 1 and 0 everywhere else. 
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillIdentity( + TensorView dst) { ///< destination tensor + + TensorFillDiagonal(dst, Element(1), Element(0)); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateDiagonal( + TensorView dst, ///< destination tensor + Element val = Element(1)) { + + typename Layout::Index extent = dst.extent().min(); + + for (typename Layout::Index i = 0; i < extent; ++i) { + Coord coord(i); + dst.at(coord) = val; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorUpdateOffDiagonalFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Element other; + + // + // Methods + // + + TensorUpdateOffDiagonalFunc( + TensorView const &view_ = TensorView(), + Element other_ = Element(0) + ): + view(view_), other(other_) { } + + void operator()(Coord const & coord) const { + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + if (!is_diag) { + view.at(coord) = other; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to all elements in the tensor without modifying diagonal elements. 
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateOffDiagonal( + TensorView dst, ///< destination tensor + Element other = Element(1)) { + + detail::TensorUpdateOffDiagonalFunc func( + dst, + other + ); + + TensorForEach( + dst.extent(), + func + ); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillLinearFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Array v; + Element s; + + // + // Methods + // + + TensorFillLinearFunc() { } + + /// Constructs functor + TensorFillLinearFunc( + TensorView const &view_, + Array const & v_, + Element s_ = Element(0) + ): + view(view_), v(v_), s(s_) { } + + /// Updates the tensor + void operator()(Coord const & coord) const { + + Element sum(s); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank; ++i) { + sum += Element(coord[i]) * v[i]; + } + + view.at(coord) = sum; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills tensor with a linear combination of its coordinate and another vector +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillLinear( + TensorView dst, ///< destination tensor + Array const & v, + Element s = Element(0)) { + + detail::TensorFillLinearFunc func( + dst, + v, + s + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills tensor with a linear combination of its coordinate and another vector +template < + typename Element, ///< Element type + typename 
Layout> ///< Layout function +void TensorFillSequential( + TensorView dst, ///< destination tensor + Element s = Element(0)) { + + Array stride; + + stride[0] = Element(1); + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + stride[i] = stride[i - 1] * Element(dst.extent()[i - 1]); + } + + TensorFillLinear(dst, stride, s); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values from a distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandom( + TensorView view, ///< destination tensor + uint64_t seed, + Distribution dist, + bool exclude_zero = false ///< If true, excludes 0. + /// Note that setting this flag will result in more 1's, + /// as we use a simple mechanism to replace 0's by adding/subtracting 1's. +) { + + using Real = typename RealType::Type; + + if (dist.kind == Distribution::Gaussian) { + TensorFillRandomGaussian( + view, + seed, + dist.gaussian.mean, + dist.gaussian.stddev, + dist.int_scale, + dist.gaussian.pnz, + exclude_zero); + } else if (dist.kind == Distribution::Uniform) { + TensorFillRandomUniform( + view, + seed, + dist.uniform.max, + dist.uniform.min, + dist.int_scale, + dist.uniform.pnan, + exclude_zero); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequential( + Element *ptr, + int64_t capacity, + Element v = Element(1), + Element s = Element(0)) { + int i = 0; + + while (i < capacity) { + cutlass::ReferenceFactory::value < + 8)>::get(ptr, i) = s; + + s = Element(s + v); + ++i; + } +} + +/// Fills a 
block of data with sequential elements +template < + typename Element +> +void BlockFillSequentialModN( + Element *ptr, + int64_t capacity, + int64_t mod, + int64_t v = int64_t(1), + int64_t s = int64_t(0)) { + int i = 0; + + while (i < capacity) { + cutlass::ReferenceFactory::value < + 8)>::get(ptr, i) = Element(s); + + s = int64_t(s + v) % mod; + ++i; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillRandom( + Element *ptr, + size_t capacity, + uint64_t seed, + Distribution dist) { + + if (dist.kind == Distribution::Gaussian) { + BlockFillRandomGaussian( + ptr, + capacity, + seed, + dist.gaussian.mean, + dist.gaussian.stddev, + dist.int_scale, + dist.gaussian.pnz); + } + else if (dist.kind == Distribution::Uniform) { + BlockFillRandomUniform( + ptr, + capacity, + seed, + dist.uniform.max, + dist.uniform.min, + dist.int_scale, + dist.uniform.pnan); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomSparseMetaFunc { + + uint64_t seed; + int range; + int MetaSizeInBits; + + // + // Methods + // + + RandomSparseMetaFunc( + uint64_t seed_ = 0, + int MetaSizeInBits_ = 2 + ): + seed(seed_), MetaSizeInBits(MetaSizeInBits_) { + std::srand((unsigned)seed); + if (MetaSizeInBits_ == 2) { + range = 6; + } + else if (MetaSizeInBits_ == 4) { + range = 2; + } + else { + throw std::invalid_argument("Invalid MetaSizeInBits"); + } + } + + /// Compute random value and update RNG state + Element operator()() const { + Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}; + Element TwoToOneMeta[2] = {0x4, 0xe}; + + 
Element * MetaArray = (MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta; + + Element result = 0x0; + + for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { + int rnd = std::rand() % range; + Element meta = MetaArray[rnd]; + + result = (Element)(result | ((Element)(meta << (i * 4)))); + } + + return result; + } +}; + +/// Computes a random sparse meta +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomSparseMetaFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomSparseMetaFunc func; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + TensorFillRandomSparseMetaFunc( + TensorView view_ = TensorView(), + RandomSparseMetaFunc func_ = RandomSparseMetaFunc() + ): + view(view_), func(func_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) const { + + view.at(coord) = func(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomSparseMeta( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + int MetaSizeInBits) { ///< 2 bit or 4 bit + + detail::RandomSparseMetaFunc random_func(seed, MetaSizeInBits); + + detail::TensorFillRandomSparseMetaFunc func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. 
+template < + typename Element ///< Element type +> +void BlockFillRandomSparseMeta( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + int MetaSizeInBits) { ///< 2 bit or 4bit + + detail::RandomSparseMetaFunc random_func(seed, MetaSizeInBits); + + for (size_t i = 0; i < capacity; ++i) { + ptr[i] = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a ell block index matrix with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomEllIdx( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + int rows, int ell_cols, int cols) { ///< dimension of the matrix + + std::srand((unsigned)seed); + + for (int i = 0; i < rows; ++i) { + int col_idx = std::rand() % cols; + + for (int j = 0; j < ell_cols; ++j) { + dst.at({i, j}) = col_idx; + + if (col_idx != -1) { + if (col_idx == (cols - 1)) { + col_idx = -1; + } else { + col_idx = std::rand() % (cols - col_idx - 1) + col_idx + 1; + } + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies a diagonal in from host memory without modifying off-diagonal elements. 
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalIn( + TensorView dst, ///< destination tensor + Element const *ptr) { ///< dense buffer of elements + + typename Layout::Index extent = dst.extent().min(); + + for (typename Layout::Index i = 0; i < extent; ++i) { + Coord coord(i); + dst.at(coord) = ReferenceFactory::get(ptr, i); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies the diagonal of a tensor into a dense buffer in host memory. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalOut( + Element *ptr, ///< dense buffer of elements + TensorView src) { ///< source tensor + + typename Layout::Index extent = src.extent().min(); + + for (typename Layout::Index i = 0; i < extent; ++i) { + Coord coord(i); + ReferenceFactory::get(ptr, i) = src.at(coord); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1b3df239a1b9d69fc12e7ec4be2de6f87b3a0e3c --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp @@ -0,0 +1,432 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. 
+*/ + +#pragma once + +// Standard Library includes +#include +#include +#include + +// Cute includes +#include "cute/tensor.hpp" + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Uniform and procedural tensor fills +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a scalar element +template +void TensorFill(Tensor dst, typename Tensor::value_type element) { + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = element; + } +} + +/// Fills a tensor with the contents of its layout +template +void TensorFillSequential(Tensor dst) { + + auto layout = dst.layout(); + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = layout(idx); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Random uniform values +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomUniformFunc { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + Element operator()() const { + + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random 
values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (int_scale >= 0) { + rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(Real(rnd)); + } + else { + result = static_cast(Real(rnd)); + } + + return result; + } +}; + +/// Partial specialization for initializing a complex value. +template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + complex operator()() const { + + Element reals[2]; + + for (int i = 0; i < 2; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(int(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + } + + return complex(reals[0], reals[1]); + } +}; + +/// Partial specialization for initializing a Quaternion value. 
+template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + Quaternion operator()() const { + + Element reals[4]; + + for (int i = 0; i < 4; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(int(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + } + + return make_Quaternion(reals[0], reals[1], reals[2], reals[3]); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template ///< Tensor object +void TensorFillRandomUniform( + Tensor dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomUniformFunc random_func(seed, max, min, bits); + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = random_func(); + } +} + +/// Fills a block with random values with a uniform random distribution. 
+template < + typename Element ///< Element type +> +void BlockFillRandomUniform( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + detail::RandomUniformFunc random_func(seed, max, min, bits); + + for (size_t i = 0; i < capacity; ++i) { + ptr[i] = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Random Gaussian +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomGaussianFunc { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1 + ): + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + Element operator()() const { + + // Box-Muller transform to generate random numbers with Normal distribution + double u1 = double(std::rand()) / double(RAND_MAX); + double u2 = double(std::rand()) / double(RAND_MAX); + + // Compute Gaussian random value + double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); + rnd = mean + stddev * rnd; + + // Scale and convert final result + Element result; + + if (int_scale >= 0) { + rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(rnd); + } + else { + result = static_cast(rnd); + } + + return result; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a 
tensor with random values with a Gaussian distribution. +template < + typename Tensor +> +void TensorFillRandomGaussian( + Tensor dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = random_func(); + } +} + +/// Fills a block with random values with a Gaussian distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomGaussian( + Element *ptr, ///< destination buffer + size_t capacity, ///< number of elements + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. 
+ + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); + + for (size_t i = 0; i < capacity; ++i) { + ptr[i] = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequential( + Element *ptr, + int64_t capacity, + Element v = Element(1), + Element s = Element(0)) { + int i = 0; + + while (i < capacity) { + + ptr[i] = Element(s + v); + ++i; + } +} + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequentialModN( + Element *ptr, + int64_t capacity, + int64_t mod, + int64_t v = int64_t(1), + int64_t s = int64_t(0)) { + int i = 0; + + while (i < capacity) { + + ptr[i] = static_cast(int32_t(int64_t(s + v) % mod)); + ++i; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h new file mode 100644 index 0000000000000000000000000000000000000000..bcb1af995805e3fbcbdbf398ce7191ea2f0dbe8d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h @@ -0,0 +1,134 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines several helpers +namespace detail { + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Index of the active rank + static int const kActiveRank = Rank - RankRemaining - 1; + + /// Constructor for general rank + TensorForEachHelper( + Func &func, + Coord const &extent, + Coord &coord) { + + for (int i = 0; i < extent.at(kActiveRank); ++i) { + coord[kActiveRank] = i; + TensorForEachHelper(func, extent, coord); + } + } +}; + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Index of the active rank + static int const kActiveRank = Rank - 1; + + /// Constructor for fastest changing rank + TensorForEachHelper( + Func &func, + Coord const &extent, + Coord &coord) { + + for (int i = 0; i < extent.at(kActiveRank); ++i) { + coord[kActiveRank] = i; + func(coord); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterates over the index space of a tensor +template < + typename Func, ///< function applied to each point in a tensor's index space + int Rank> ///< rank of index space +void TensorForEach(Coord extent, Func & func) { + Coord coord; + detail::TensorForEachHelper(func, extent, coord); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterates over the index space of a tensor and calls a C++ lambda +template < + typename Func, ///< function applied to each point in a tensor's index space + int Rank> ///< rank of index space +void TensorForEachLambda(Coord extent, Func func) { + Coord coord; + 
detail::TensorForEachHelper(func, extent, coord); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockForEach { + + /// Constructor performs the operation. + BlockForEach( + Element *ptr, + size_t capacity, + typename Func::Params params = typename Func::Params()) { + + Func func(params); + + for (size_t index = 0; index < capacity; ++index) { + ptr[index] = func(); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..d44dda1f5472f13b7212f7e2e4020e254ff92f88 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h @@ -0,0 +1,42 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + + +#include "cutlass/cutlass.h" + +// The contents of this file have been moved to 'tensor_reduce' to cover other types of reductions. 
+ +#include "cutlass/util/reference/host/tensor_reduce.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..887c568059a90f749fc0ac75dd211ce77085a5a9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h @@ -0,0 +1,203 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/util/reference/detail/linear_to_coordinate.h" +#include "cutlass/core_io.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view, + ComputeType identity, + ReduceOp reduce, + TransformOp transform +) { + + for (int64_t idx = 0; idx < int64_t(view.size()); ++idx) { + typename Layout::TensorCoord coord; + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view.extent()); + + if (view.contains(coord)) { + Element x = view.at(coord); + identity = reduce(identity, transform(x)); + } + } + + return identity; +} + +/// Transform-reduce operation over the elements of a tensor. 
This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view_A, + TensorView view_B, + ComputeType identity, + ReduceOp reduce, + TransformOp transform) { + + if (view_A.extent() != view_B.extent()) { + throw std::runtime_error("Tensor extents must match."); + } + + for (int64_t idx = 0; idx < int64_t(view_A.size()); ++idx) { + + typename Layout::TensorCoord coord; + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view_A.extent()); + + if (view_A.contains(coord)) { + Element a = view_A.at(coord); + Element b = view_B.at(coord); + identity = reduce(identity, transform(a, b)); + } + } + + return identity; +} + +/// Helper to compute the sum of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSum( + TensorView view, + ComputeType identity = ComputeType() +) { + + plus reduce; + NumericConverter transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the sum of the squares of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSumSq( + TensorView view, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the norm of the elements of a tensor. 
+template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNorm( + TensorView view, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSq(view, identity)); +} + +/// Helper to compute the sum of the squares of the differences of two tensors +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorSumSqDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared_difference transform; + + return TensorTransformReduce( + view_A, view_B, identity, reduce, transform); +} + + +/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNormDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSqDiff(view_A, view_B, identity)); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ea711466df86703aae1702605a928754c9f4e944 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp @@ -0,0 +1,203 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 
NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. 
+*/ + +#pragma once + +// Standard Library includes +#include +#include +#include + +// Cute includes +#include "cute/tensor.hpp" + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Tensor reductions +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Tensor, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + Tensor view, + ComputeType identity, + ReduceOp reduce, + TransformOp transform +) { + + for (int64_t idx = 0; idx < cute::size(view); ++idx) { + identity = reduce(identity, transform(view(idx))); + } + + return identity; +} + +/// Transform-reduce operation over the elements of a tensor. 
This helper allocates the device-side +/// workspace +template < + typename TensorA, + typename TensorB, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorA view_A, + TensorB view_B, + ComputeType identity, + ReduceOp reduce, + TransformOp transform) { + + if (cute::size(view_A) != cute::size(view_B)) { + throw std::runtime_error("Tensor sizes must match."); + } + + for (int64_t idx = 0; idx < cute::size(view_A); ++idx) { + identity = reduce(identity, transform(view_A(idx), view_B(idx))); + } + + return identity; +} + +/// Helper to compute the sum of the elements of a tensor +template < + typename Tensor, + typename ComputeType = typename Tensor::value_type +> +ComputeType TensorSum( + Tensor view, + ComputeType identity = ComputeType() +) { + + plus reduce; + NumericConverter transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the sum of the squares of the elements of a tensor +template < + typename Tensor, + typename ComputeType = typename Tensor::value_type +> +ComputeType TensorSumSq( + Tensor view, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the norm of the elements of a tensor. 
+template < + typename Tensor, + typename ComputeType = double +> +ComputeType TensorNorm( + Tensor view, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSq(view, identity)); +} + +/// Helper to compute the sum of the squares of the differences of two tensors +template < + typename TensorA, + typename TensorB, + typename ComputeType = double +> +ComputeType TensorSumSqDiff( + TensorA view_A, + TensorB view_B, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared_difference transform; + + return TensorTransformReduce( + view_A, view_B, identity, reduce, transform); +} + + +/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory +template < + typename TensorA, + typename TensorB, + typename ComputeType = double +> +ComputeType TensorNormDiff( + TensorA view_A, + TensorB view_B, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSqDiff(view_A, view_B, identity)); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h new file mode 100644 index 0000000000000000000000000000000000000000..09b1aff9c0ea9922af46c928a3dd61595be2e4cd --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h @@ -0,0 +1,215 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for TRMM in host-side code. 
+ + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" + +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_trmm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + static_assert(SideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper + , "Fill Mode can either be Lower or Upper."); + + using CompareOp = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. 
+ int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp compare_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = ElementA(); + ElementB b = ElementB(); + + if (SideModeA == SideMode::kLeft) { + a = (compare_op(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0); + if (row == k_block && DiagTypeA == DiagType::kUnit) { + a = ElementA(1); + } + b = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a = tensor_b.at(MatrixCoord(row, k_block)); + b = (compare_op(k_block, col)) ? 
+ tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0); + if (k_block == col && DiagTypeA == DiagType::kUnit) { + b = ElementA(1); + } + } + + ComputeType compute_a(cast_if_scalar(a)); + ComputeType compute_b(cast_if_scalar(b)); + + accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j])); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Trmm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Trmm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_trmm>( + problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h 
b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..e8db2a4deaf8608882595d68e611f8ae79e134e8 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h @@ -0,0 +1,262 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued TRMM in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef +/// objects. 
+template < + typename ElementA, + typename LayoutA, + ComplexTransform TransformA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + ComplexTransform TransformB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_trmm_complex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + static_assert(SideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper + , "Fill Mode can either be Lower or Upper."); + + using CompareOp = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp compare_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = ElementA(); + ElementB b = ElementB(); + + if (SideModeA == SideMode::kLeft) { + a = (compare_op(row, k_block)) ? 
+ (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0); + if (row == k_block && DiagTypeA == DiagType::kUnit) { + a = ElementA(1); + } + b = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a = tensor_b.at(MatrixCoord(row, k_block)); + b = (compare_op(k_block, col)) ? + tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0); + if (k_block == col && DiagTypeA == DiagType::kUnit) { + b = ElementA(1); + } + } + + ComputeType a_ik = ComputeType(a); + ComputeType b_kj = ComputeType(b); + + // Conjugate, and hence hermitian, is only allowed for the triangular matrix + if (SideModeA == SideMode::kLeft && TransformA == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } else if (SideModeA == SideMode::kRight && TransformA == ComplexTransform::kConjugate) { + b_kj = conj(b_kj); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j])); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + ComplexTransform TransformA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + ComplexTransform TransformB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex +> +struct TrmmComplex; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct TrmmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef 
tensor_b, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_trmm_complex>( + problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for gaussian multiply-add +template +struct TrmmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_trmm_complex>( + problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/tensor_view_io.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/tensor_view_io.h new file mode 100644 index 0000000000000000000000000000000000000000..0ce1d8a65fdd66ace69f91525b678dd6ad132d24 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/tensor_view_io.h @@ -0,0 +1,270 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+* +**************************************************************************************************/ +#pragma once + +#include "cutlass/core_io.h" +#include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" +#include "cutlass/complex.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Helper to write the least significant rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorView_WriteLeastSignificantRank( + std::ostream& out, + TensorView const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (idx) { + out.width(0); + out << ", "; + } + if (idx || coord) { + out.width(width); + } + out << ScalarIO(view.at(coord)); + } + + return out; +} + +/// Helper to write a rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorView_WriteRank( + std::ostream& out, + TensorView const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + // If called on the least significant rank, write the result as a row + if (rank + 1 == Layout::kRank) { + return TensorView_WriteLeastSignificantRank(out, view, start_coord, rank, width); + } + + // Otherwise, write a sequence of rows and newlines + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (rank + 2 == Layout::kRank) { + // Write least significant ranks asa matrix with rows delimited by "\n" + if (idx) { + out << ",\n"; + } + TensorView_WriteLeastSignificantRank(out, view, coord, rank + 1, width); + } + else { + // Higher ranks are separated by newlines + if (idx) { + out << ",\n\n"; + } + TensorView_WriteRank(out, view, coord, rank + 1, width); + } + } + + return out; +} + +/// Helper to 
write the least significant rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorViewPlanarComplex_WriteLeastSignificantRank( + std::ostream& out, + TensorViewPlanarComplex const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (idx) { + out.width(0); + out << ", "; + } + if (idx || coord) { + out.width(width); + } + + complex x = view.at(coord); + out << x; + } + + return out; +} + +/// Helper to write a rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorViewPlanarComplex_WriteRank( + std::ostream& out, + TensorViewPlanarComplex const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + // If called on the least significant rank, write the result as a row + if (rank + 1 == Layout::kRank) { + return TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, start_coord, rank, width); + } + + // Otherwise, write a sequence of rows and newlines + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (rank + 2 == Layout::kRank) { + // Write least significant ranks asa matrix with rows delimited by ";\n" + if (idx) { + out << ";\n"; + } + TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, coord, rank + 1, width); + } + else { + // Higher ranks are separated by newlines + if (idx) { + out << "\n"; + } + TensorViewPlanarComplex_WriteRank(out, view, coord, rank + 1, width); + } + } + + return out; +} + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& TensorViewWrite( + std::ostream& out, + TensorView const& view) { + + // 
Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return detail::TensorView_WriteRank(out, view, Coord(), 0, out.width()); +} + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& operator<<( + std::ostream& out, + TensorView const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return TensorViewWrite(out, view); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& TensorViewWrite( + std::ostream& out, + TensorViewPlanarComplex const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. 
+ + return detail::TensorViewPlanarComplex_WriteRank(out, view, Coord(), 0, out.width()); +} + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& operator<<( + std::ostream& out, + TensorViewPlanarComplex const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return TensorViewWrite(out, view); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/type_traits.h b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/type_traits.h new file mode 100644 index 0000000000000000000000000000000000000000..5dfbfe274dec368cfac291a1c78ece6ffb203c72 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/type_traits.h @@ -0,0 +1,238 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Type traits for common CUDA types +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/numeric_types.h" +#include "cutlass/complex.h" + +namespace cutlass { +struct half_t; + +template +struct TypeTraits { + typedef T host_type; + typedef T device_type; + static inline T remove_negative_zero(T x) { return x; } + static inline T to_print(T x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_8I; + typedef int8_t host_type; + typedef int8_t device_type; + typedef int8_t integer_type; + typedef uint8_t unsigned_type; + static inline int8_t remove_negative_zero(int8_t x) { return x; } + static inline int to_print(int8_t x) { return (int)x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_8I; + typedef uint8_t host_type; + typedef uint8_t device_type; + typedef uint8_t integer_type; + typedef uint8_t unsigned_type; + static inline uint8_t remove_negative_zero(uint8_t x) { return x; } + static inline uint32_t to_print(uint8_t x) { return (uint32_t)x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_32I; + typedef int host_type; + typedef int device_type; + typedef int32_t integer_type; + typedef uint32_t unsigned_type; + static inline int32_t remove_negative_zero(int32_t x) { return x; } + static inline int to_print(int x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_32I; + typedef unsigned host_type; + typedef unsigned device_type; + typedef uint32_t integer_type; + typedef uint32_t unsigned_type; + static inline uint32_t remove_negative_zero(uint32_t x) { return x; } + static 
inline uint32_t to_print(uint32_t x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_8I; + typedef int64_t host_type; + typedef int64_t device_type; + typedef int64_t integer_type; + typedef uint64_t unsigned_type; + static inline int64_t remove_negative_zero(int64_t x) { return x; } + static inline int64_t to_print(int64_t x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_8I; + typedef uint64_t host_type; + typedef uint64_t device_type; + typedef uint64_t integer_type; + typedef uint64_t unsigned_type; + static inline uint64_t remove_negative_zero(uint64_t x) { return x; } + static inline uint64_t to_print(uint64_t x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_16F; + typedef half_t host_type; + typedef half_t device_type; + typedef int16_t integer_type; + typedef uint16_t unsigned_type; + static inline half_t remove_negative_zero(half_t x) { + return (x.raw() == 0x8000 ? half_t::bitcast(0) : x); + } + static inline half_t to_print(half_t x) { return x; } + static inline device_type to_device(half_t x) { return reinterpret_cast(x); } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_32F; + typedef float host_type; + typedef float device_type; + typedef int32_t integer_type; + typedef uint32_t unsigned_type; + static inline float remove_negative_zero(float x) { return x == -0.f ? 
0.f : x; } + static inline float to_print(float x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_64F; + typedef double host_type; + typedef double device_type; + typedef int64_t integer_type; + typedef uint64_t unsigned_type; + static inline double remove_negative_zero(double x) { return x == -0.0 ? 0.0 : x; } + static inline double to_print(double x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Complex types +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct TypeTraits > { + static cudaDataType_t const cublas_type = CUDA_C_16F; + typedef complex host_type; + typedef complex device_type; + typedef int16_t integer_type; + typedef uint16_t unsigned_type; + static inline device_type to_device(complex x) { return reinterpret_cast(x); } +}; + +template <> +struct TypeTraits > { + static cudaDataType_t const cublas_type = CUDA_C_16F; + typedef complex host_type; + typedef complex device_type; + typedef int16_t integer_type; + typedef uint16_t unsigned_type; + static inline complex remove_negative_zero(complex x) { + return complex( + real(x) == -0_hf ? 0_hf : real(x), + imag(x) == -0_hf ? 0_hf : imag(x) + ); + } + static inline complex to_print(complex x) { return x; } + static inline device_type to_device(complex x) { return reinterpret_cast(x); } +}; + +template <> +struct TypeTraits > { + + static cudaDataType_t const cublas_type = CUDA_C_32F; + typedef complex host_type; + typedef complex device_type; + typedef int64_t integer_type; + typedef uint64_t unsigned_type; + + static inline complex remove_negative_zero(complex x) { + return complex( + real(x) == -0.f ? 0.f : real(x), + imag(x) == -0.f ? 
0.f : imag(x) + ); + } + + static inline complex to_print(complex x) { return x; } + static inline device_type to_device(complex x) { return reinterpret_cast(x); } +}; + +template <> +struct TypeTraits > { + static cudaDataType_t const cublas_type = CUDA_C_64F; + typedef complex host_type; + typedef complex device_type; + struct integer_type { int64_t real, imag; }; + struct unsigned_type { uint64_t real, imag; }; + static inline complex remove_negative_zero(complex x) { + return complex( + real(x) == -0.0 ? 0.0 : real(x), + imag(x) == -0.0 ? 0.0 : imag(x) + ); + } + static inline complex to_print(complex x) { return x; } + static inline device_type to_device(complex x) { return reinterpret_cast(x); } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/scripts/split_test_cmake.py b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/scripts/split_test_cmake.py new file mode 100644 index 0000000000000000000000000000000000000000..6541ce1b26722ff1f0dba0b4c034067a62f9b96d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/include/third-party/cutlass/tools/util/scripts/split_test_cmake.py @@ -0,0 +1,356 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + + +""" +Given a set of test files to be included in a CMake target, this script extracts +the TEST definitions from each file, writes them into new files, and prints the names +of the new files so that they can be processed as part of a new CMake target. + +For example, given a set of --src_files test_a.cu test_b.cu containing 3 and 2 TEST +definitions, respectively, this script would produce: + test_a_000.cu + test_a_001.cu + test_a_002.cu + test_b_000.cu + test_b_001.cu + +The splitting follows a fairly rudimentary algorithm that does not support all valid C++ programs. +We walk through a given input test file line by line. 
Any line that is not within a TEST definition is added to a running +"filler" text. When a TEST definition is encountered, the current filler text becomes the prefix +for that test. All subsequent lines are considered to be part of the TEST definition until the +number of starting function braces ('{') matches the number of closing function braces ('}'). When +these counts are equal, the TEST definition is considered to be completed. At this point, we return +to adding lines to the "filler" text until a new TEST definition is encountered. Any "filler" text +following a TEST definition is added to the suffix of that TEST definition (this is useful for finishing +off #if statements, as is common in unit tests). + +A state machine illustrating this algorithm at a high level is provided in the source below. + +Example: Suppose an input test `test.cu` has the following source: + // COPYRIGHT + #include <iostream> + + #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // Test #1 + TEST(SM90_a, 256x128x64_2x2x1) { + std::cout << "Test #1" << std::endl; + } + + // Test #2 + TEST(SM90_b, 256x128x64_1x1x1) { + std::cout << "Test #2" << std::endl; + } + + #endif defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +The contents of the two resulting test files will be: + $ cat test_000.cu + // COPYRIGHT + #include <iostream> + + #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // Test #1 + TEST(SM90_a, 256x128x64_2x2x1) { + std::cout << "Test #1" << std::endl; + } + + // Test #2 + + #endif defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + $ cat test_001.cu + // COPYRIGHT + #include <iostream> + + #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // Test #1 + + // Test #2 + TEST(SM90_b, 256x128x64_1x1x1) { + std::cout << "Test #2" << std::endl; + } + + #endif defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +Notice that each of test_000.cu and test_001.cu contains comments that appear outside +the TEST definitions not included in each file. This is by design, as these +would be considered "filler" text. 
+ +As expected, some cases can't be handled. Below is a non-exhaustive list: + 1. New TEST following the closing '}' of a TEST case on the same line: + TEST(x, y) { + // Do stuff + } TEST(a, b) { + + In this case, "TEST(a, b) {" will be ignored + + 2. Preprocessor macros that occur midway through a test case and extend + beyond the conclusion of a testcase + + Example: + TEST(a, b) { + // Do stuff + #if X + // Do more stuff + } + #else + // Do other stuff + } + #endif +""" + + +import argparse +import enum +import os + + +parser = argparse.ArgumentParser() +parser.add_argument("cmake_target", type=str, + help="Name of the CMake target being generated.") +parser.add_argument("src_dir", type=str, + help="Path to the directory containing test files.") +parser.add_argument("--src_files", nargs='+', + help="Files containing TEST instances to split.") +parser.add_argument("--max_tests_per_file", type=int, default=1, + help="Maximum number of TEST instances per file.") +parser.add_argument("--dst_dir", type=str, + help="Path to the directory to which to write new test files. If not set, uses src_dir.") +args = parser.parse_args() + + +if args.dst_dir == None: + args.dst_dir = args.src_dir + + +class Testcase: + """ + Lightweight tracker of test-case processing status + """ + def __init__(self, prefix_text): + # Any text that preceded the TEST definition that was + # not part of another TEST definition + self.prefix = prefix_text + + # Any text within the TEST definition + self.test = "" + + # Any text that follows the completion of the TEST definition + # and is not included in other TEST definitions + self.suffix = "" + + # Whether the test's definition has concluded + self.completed = False + + # Current balance of opening and closing curly brackets in + # the TEST definition. '{' increments the count and '}' decrements it. + # A value of 0 (when self.completed == False) indicates that the test + # has completed. 
+ self.curly_bracket_balance = 0 + + +class ParseState(enum.Enum): + """ + State machine for processing. + Transitions occur on each line encountered in the source file + + + Line does not contain 'TEST(' + +----+ + | | + | v 'TEST(' + +--------+ encountered +--------------------------+ + ------>| Filler | -----------------------> | TestDeclaredWaitingStart | + +--------+ +--------------------------+ + ^ | + Number of '{' | | First '{' encountered + equals number of | +--------+ | + '}' encountered +-----------| InTest | <------------------+ + +--------+ + | ^ + | | + +----+ + Number of '{' encountered + exceeds number of '}' encountered + """ + + + # Any text that is not part of a TEST case + Filler = 0 + + # Processing text within the first { of the TEST case + # and before the end of the final } of the TEST case + InTest = 1 + + # Processing text from the start of the TEST definition + # but before the first {. This could occur if the opening { + # occurs on a separate line than the TEST definition. + TestDeclaredWaitingStart = 2 + + +cmake_src_list = [] +for filename in args.src_files: + if '.' not in filename: + # Add any non-filename arguments to the command list by default + cmake_src_list.append(filename) + continue + + if '/' in filename: + raise Exception( + f"Source files passed to {__file__} must be within the same directory " + "as the CMakeLists defining the target using the files. " + f"Provided path {filename} is in a different directory.") + + full_filename = os.path.join(args.src_dir, filename) + with open(full_filename, 'r') as infile: + lines = infile.readlines() + + # Find the number of instances of "TEST(" + ntest = sum([1 for line in lines if "TEST(" in line]) + + if ntest <= args.max_tests_per_file: + # File contains fewer than max_tests_per_file TEST instances. It does + # not need to be split + cmake_src_list.append(filename) + continue + + # Current state of the parsing state machine. 
We start with filler text + state = ParseState.Filler + + # List of individual TESTs found + tests = [] + + # Ongoing text that is not included in a TEST definition. This will serve + # as the prefix for any yet-to-be encountered TEST definitions. + filler_text = "" + + def add_filler_text(text): + global filler_text + # Add new text to the ongoing filler text and to the suffixes of + # any completed tests + filler_text += text + for i in range(len(tests)): + if tests[i].completed: + tests[i].suffix += text + + for line in lines: + if state == ParseState.Filler: + # We are not currently within a TEST definition. + + if 'TEST(' in line: + # We have encountered a new TEST( case. Any text preceding this + # must be added to the filler text (e.g., if we have a line of the form: + # "static constexpr int Val = 4; TEST(blah) {" + # then "static constexpr int Val = 4;" needs to be included in filler + # text, as it could be used by subsequent tests.) + splits = line.split('TEST') + + # There should not be more than one TEST definition on a given line + assert len(splits) <= 2 + + if len(splits) > 1: + if not splits[0].isspace(): + # Only add text to filler if there are non-whitespace characters + # preceding the TEST definition in the line + filler_text += splits[0] + + # The new line is just the TEST-related line + line = 'TEST' + splits[-1] + + # Add tests and transition to TestDeclaredWaitingStart state. + # Do not add the line to the test text of the new test case; this + # will be done in either the TestDeclaredWaitingStart state processing + # below or in the InTest state processing below. 
+ tests.append(Testcase(filler_text)) + state = ParseState.TestDeclaredWaitingStart + else: + # Any remaining filler text is added to the running filler_text + # which will be used as the prefix for any new tests, and to the + # suffix of any completed tests + add_filler_text(line) + + if state == ParseState.TestDeclaredWaitingStart: + # We have seen a TEST definition but have not yet seen its opening {. + + if '{' in line: + # The first curly bracket for the TEST definition has been found. + # Advance to state InTests. Do not add the line to the test's text + # or change the curly-brace balance of the test; these will be done + # when processing the state == ParseState.InTest condition below. + state = ParseState.InTest + else: + tests[-1].test += line + + if state == ParseState.InTest: + # We are currently within a TEST definition. + # Process lines character-by-character looking for opening and closing + # braces. If we reach parity between opening and closing braces, the + # test is considered done. + filler_text_to_add = "" + for char in line: + if not tests[-1].completed: + tests[-1].test += char + if char == '{': + tests[-1].curly_bracket_balance += 1 + elif char == '}': + tests[-1].curly_bracket_balance -= 1 + if tests[-1].curly_bracket_balance == 0: + tests[-1].completed = True + else: + filler_text_to_add += char + + if filler_text_to_add != "" and (not filler_text_to_add.isspace() or '\n' in filler_text_to_add): + add_filler_text('\n' + filler_text_to_add) + + if tests[-1].completed: + state = ParseState.Filler + + # Write out the new files for tests + filename_prefix, filename_suffix = filename.split('.') + for i, test in enumerate(tests): + assert test.completed + new_filename = filename_prefix + '_' + str(i).zfill(3) + '.' + filename_suffix + full_new_filename = os.path.join(args.dst_dir, new_filename) + + # Replace any '\' with '/'. CMake doesn't like '\'. 
+ full_new_filename = full_new_filename.replace('\\', '/') + + with open(full_new_filename, 'w') as outfile: + outfile.write(test.prefix + test.test + test.suffix) + cmake_src_list.append(full_new_filename) + + +for cmake_file in cmake_src_list: + print(cmake_file) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/metadata.json b/build/torch29-cxx11-cu129-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..4899badb63d45293425e2164944268b6058af95d --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/metadata.json @@ -0,0 +1,11 @@ +{ + "version": 1, + "license": "MIT", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "9.0a" + ] + } +} diff --git a/build/torch29-cxx11-cu129-x86_64-linux/testing/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13a9d78dea58a6492183f9ddc50f1510a679cbe6 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/testing/__init__.py @@ -0,0 +1,4 @@ +from . 
import bench, numeric, utils +from .bench import * +from .numeric import * +from .utils import * diff --git a/build/torch29-cxx11-cu129-x86_64-linux/testing/bench.py b/build/torch29-cxx11-cu129-x86_64-linux/testing/bench.py new file mode 100644 index 0000000000000000000000000000000000000000..2c752da2d3bb0aba7e03ef1921428432b396917a --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/testing/bench.py @@ -0,0 +1,137 @@ +import os +import sys +import torch + + +def bench(fn, num_warmups: int = 5, num_tests: int = 10, + high_precision: bool = False): + # Flush L2 cache with 256 MB data + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') + cache.zero_() + + # Warmup + for _ in range(num_warmups): + fn() + + # Add a large kernel to eliminate the CPU launch overhead + if high_precision: + x = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + y = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + x @ y + + # Testing + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for i in range(num_tests): + fn() + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_tests / 1e3 + + +class empty_suppress: + def __enter__(self): + return self + + def __exit__(self, *_): + pass + + +class suppress_stdout_stderr: + def __enter__(self): + self.outnull_file = open(os.devnull, 'w') + self.errnull_file = open(os.devnull, 'w') + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + + sys.stdout = self.outnull_file + sys.stderr = 
self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + self.outnull_file.close() + self.errnull_file.close() + + +def bench_kineto(fn, kernel_names, num_tests: int = 30, + suppress_kineto_output: bool = False, + trace_path: str = None, flush_l2: bool = True, + with_multiple_kernels: bool = False): + assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) + is_tuple = isinstance(kernel_names, tuple) + + # Skip profiling + # Conflict with Nsight Systems, Nsight Compute and Compute Sanitizer + if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)): + return (1, ) * len(kernel_names) if is_tuple else 1 + + # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle + flush_l2_size = int(8e9 // 4) + + # For some auto-tuning kernels with prints + fn() + + # Profile + suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress + with suppress(): + schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) + profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) + with profiler: + for i in range(2): + for _ in range(num_tests): + if flush_l2: + torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() + fn() + profiler.step() + + # Parse the profiling table + prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') + kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names + if not with_multiple_kernels: + for name in kernel_names: + assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table' + + # Save chrome traces + if 
trace_path is not None: + profiler.export_chrome_trace(trace_path) + + # Return average kernel times + units = {'ms': 1e3, 'us': 1e6} + kernel_times = [] + for name in kernel_names: + total_time = 0 + total_num = 0 + for line in prof_lines: + if name in line: + time_str = line.split()[-2] + num_str = line.split()[-1] + for unit, scale in units.items(): + if unit in time_str: + total_time += float(time_str.replace(unit, '')) / scale * int(num_str) + total_num += int(num_str) + break + kernel_times.append(total_time / total_num if total_num > 0 else 0) + + return tuple(kernel_times) if is_tuple else kernel_times[0] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/testing/numeric.py b/build/torch29-cxx11-cu129-x86_64-linux/testing/numeric.py new file mode 100644 index 0000000000000000000000000000000000000000..a42c4318db47593c47a4ea89fbdbcb1ffb5cd30e --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/testing/numeric.py @@ -0,0 +1,21 @@ +import torch +from typing import Iterable + + +def calc_diff(x: torch.Tensor, y: torch.Tensor): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + if denominator == 0: # Which means that all elements in x and y are 0 + return 0.0 + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def count_bytes(*tensors): + total = 0 + for t in tensors: + if isinstance(t, (tuple, list)): + total += count_bytes(*t) + elif t is not None: + total += t.numel() * t.element_size() + return total diff --git a/build/torch29-cxx11-cu129-x86_64-linux/testing/utils.py b/build/torch29-cxx11-cu129-x86_64-linux/testing/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2d202d4192ed385f986ac5cc216acc69378d8ea9 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/testing/utils.py @@ -0,0 +1,38 @@ +import functools +import os +import torch +from typing import Callable + +def get_arch_major() -> int: + major, minor = torch.cuda.get_device_capability() + return major + + +def 
test_filter(condition: Callable): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if condition(): + func(*args, **kwargs) + else: + print(f'{func.__name__}:') + print(f' > Filtered by {condition}') + print() + return wrapper + return decorator + + +def ignore_env(name: str, condition: Callable): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if condition(): + saved = os.environ.pop(name, None) + func(*args, **kwargs) + if saved is not None: + os.environ[name] = saved + else: + func(*args, **kwargs) + + return wrapper + return decorator diff --git a/build/torch29-cxx11-cu129-x86_64-linux/utils/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e8f859a20726fcc0ea32c54ed8df37b19b3960a4 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/utils/__init__.py @@ -0,0 +1,3 @@ +from . import math, layout +from .layout import * +from .math import * diff --git a/build/torch29-cxx11-cu129-x86_64-linux/utils/layout.py b/build/torch29-cxx11-cu129-x86_64-linux/utils/layout.py new file mode 100644 index 0000000000000000000000000000000000000000..a6bc29d9aaae296a83b8c3546b832a083ade6b28 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/utils/layout.py @@ -0,0 +1,25 @@ +from .._ops import ops + + +def get_mk_alignment_for_contiguous_layout(): + return ops.get_mk_alignment_for_contiguous_layout() + + +def get_tma_aligned_size(mn: int, element_size: int): + return ops.get_tma_aligned_size(mn, element_size).item() + + +def get_mn_major_tma_aligned_tensor(sf): + return ops.get_mn_major_tma_aligned_tensor(sf) + + +def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf): + return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) + + +def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks): + return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks) + + 
+get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout +get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout diff --git a/build/torch29-cxx11-cu129-x86_64-linux/utils/math.py b/build/torch29-cxx11-cu129-x86_64-linux/utils/math.py new file mode 100644 index 0000000000000000000000000000000000000000..c65026e54b87faf34b498d14d3f81a94759615f4 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/utils/math.py @@ -0,0 +1,107 @@ +import torch +from typing import Tuple + + +def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +def align(x: int, y: int) -> int: + return ceil_div(x, y) * y + + +def ceil_to_ue8m0(x: torch.Tensor): + assert x.view(-1).amax().item() > 0 + return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + + +def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + padded_n = align(n, gran_k) + x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0) + x_padded[:, :n] = x + x_view = x_padded.view(m, -1, gran_k) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf + + +def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(0) % gran_k == 0 + m, n = x.shape + x_view = x.view(-1, gran_k, n) + x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf + + +def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = 
torch.zeros((align(m, gran_k), align(n, gran_k)), dtype=x.dtype, device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, gran_k, x_padded.size(1) // gran_k, gran_k) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(x_view.size(0), x_view.size(2)) + + +def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) + x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled, sf.squeeze() + + +def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: + ax = x.abs().clamp_max(6.0) + # {0, 0.5, 1, 1.5, 2, 3, 4, 6} + # midpoints: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0 + boundaries = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0], + device=x.device, dtype=ax.dtype) + idx = torch.bucketize(ax, boundaries) + code = idx.to(torch.uint8) + sign = (x < 0) & (idx != 0) + code = code | (sign.to(torch.uint8) << 3) + return code # uint8, 0..15 + + +def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + assert n % 2 == 0 + padded_n = align(n, gran_k) + x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device) + x_padded[:, :n] = x + x_view = x_padded.view(m, -1, gran_k) + x_amax = x_view.abs().float().amax(dim=2).clamp_min(1e-4) + sf = x_amax / 6.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = x_view * (1.0 / sf.unsqueeze(2)) + codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # uint8, (m, padded_n) + codes2 = codes.view(m, padded_n // 
2, 2) + packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # uint8 + return packed[:, :n // 2].contiguous(), sf + + +def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor: + assert a.dtype == torch.uint8 + assert a.dim() == 2 + m, n2 = a.shape + n = n2 * 2 + assert (m % 2) == 0 + lo = a & 0x0F + hi = (a >> 4) & 0x0F + codes = torch.empty((m, n), device=a.device, dtype=torch.uint8) + codes[:, 0::2], codes[:, 1::2] = lo, hi + codes_t = codes.transpose(0, 1).contiguous() + codes2 = codes_t.view(n, m // 2, 2) + out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) + return out.contiguous() \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..5f42109466887b22d1b45da0dcc4e3b89d856cea --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b3d7efd79f90bef6a17ba40e9c00f0db399b73f9c4a281b6eab2cf6aadd6041 +size 2875040 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py index 9ef78b878692ffa131b9fd2dda7920174466c7cf..a1598d972cf45498451b415d7a5869693e9b4a96 100644 --- a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py +++ b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _deep_gemm_cuda_5c4fc0a -ops = torch.ops._deep_gemm_cuda_5c4fc0a +from . import _deep_gemm_cuda_a68a39f +ops = torch.ops._deep_gemm_cuda_a68a39f def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_deep_gemm_cuda_5c4fc0a::{op_name}" + return f"_deep_gemm_cuda_a68a39f::{op_name}"